Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 140 additions & 70 deletions src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,119 @@ async def _execute_query(self, sql: str, params: Optional[tuple] = None, databas
conn_state = f"Connection: {'acquired' if conn else 'not acquired'}"
logger.error(f"Unexpected error during query execution ({conn_state}): {e}", exc_info=True)
raise RuntimeError(f"An unexpected error occurred: {e}") from e

def _has_properly_escaped_backticks(self, identifier: str) -> bool:
"""
Check that any backticks inside quoted identifier are properly escaped (doubled).
Returns True if all backticks are properly escaped, False otherwise.

- identifier (str): quoted identifier.
"""
i = 0
while i < len(identifier):
if identifier[i] == '`':
if i + 1 < len(identifier) and identifier[i + 1] == '`':
i += 2 # Skip the escaped pair
else:
return False # Found unescaped backtick
else:
i += 1
return True

def _quote_identifier(self, identifier: str) -> str:
"""
Quote an identifier
If already quoted, returns as-is. If unquoted, wraps in backticks.

Parameters:
- identifier (str): The identifier to quote

Returns:
- str: Quoted identifier
"""
if identifier is None:
raise ValueError("Identifier cannot be None")

async def _database_exists(self, database_name: str) -> bool:
"""Checks if a database exists."""
if not database_name or not database_name.isidentifier():
logger.warning(f"_database_exists called with invalid database_name: {database_name}")
return False
if identifier.startswith('`') and identifier.endswith('`'):
# Already quoted, return as-is
return identifier
else:
# Unquoted, wrap in backticks and escape any existing backticks
escaped_content = identifier.replace('`', '``')
return f'`{escaped_content}`'


def _is_valid_identifier(self, identifier: str) -> bool:
"""
Validates MariaDB identifier that will be quoted when used in SQL.
Accepts both quoted and unquoted identifiers since all identifiers
are treated as "will be quoted when needed".

Parameters:
- identifier (str): identifier (quoted or unquoted).
"""
if not identifier:
return False

# If unquoted, quote it first, then validate as quoted identifier
if not (identifier.startswith('`') and identifier.endswith('`')):
identifier = self._quote_identifier(identifier)

# Now validate as quoted identifier
if len(identifier) <= 2:
return False

actual_name = identifier[1:-1]

# Check that any backticks inside are properly escaped (doubled)
if not self._has_properly_escaped_backticks(actual_name):
return False

# Handle escaped backticks to get the real length
escaped_name = actual_name.replace('``', '`')

# Basic length check
if len(escaped_name) > 64:
return False

# No trailing spaces allowed
if escaped_name.endswith(' '):
return False

# No null characters
if '\x00' in escaped_name:
return False

return True

def _normalize_identifier(self, identifier: str, method_name: str) -> str:
"""
Normalizes and validates an identifier for MCP methods.
Validates the identifier and returns the unquoted version.

Parameters:
- identifier (str): identifier (quoted or unquoted)
- method_name (str): name of the calling method for error messages

Returns:
- str: unquoted identifier

Raises:
- ValueError: if identifier is invalid
"""
if not self._is_valid_identifier(identifier):
error_msg = f"Invalid identifier '{identifier}' in {method_name}"
logger.error(error_msg)
raise ValueError(error_msg)

# Strip quotes if present
if identifier.startswith('`') and identifier.endswith('`'):
return identifier[1:-1].replace('``', '`')
else:
return identifier

async def _database_exists(self, database_name: str) -> bool:
"""Checks if a database exists. Expects normalized (unquoted) database name."""
sql = "SELECT SCHEMA_NAME FROM information_schema.SCHEMATA WHERE SCHEMA_NAME = %s"
try:
results = await self._execute_query(sql, params=(database_name,), database='information_schema')
Expand All @@ -176,12 +282,7 @@ async def _database_exists(self, database_name: str) -> bool:
return False

async def _table_exists(self, database_name: str, table_name: str) -> bool:
"""Checks if a table exists in the given database."""
if not database_name or not database_name.isidentifier() or \
not table_name or not table_name.isidentifier():
logger.warning(f"_table_exists called with invalid names: db='{database_name}', table='{table_name}'")
return False

"""Checks if a table exists in the given database. Expects normalized (unquoted) names."""
sql = "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_SCHEMA = %s AND TABLE_NAME = %s"
try:
results = await self._execute_query(sql, params=(database_name, table_name), database='information_schema')
Expand All @@ -205,10 +306,7 @@ async def _is_vector_store(self, database_name: str, table_name: str) -> bool:
"""
logger.debug(f"Checking if '{database_name}.{table_name}' is a vector store.")

if not database_name or not database_name.isidentifier() or \
not table_name or not table_name.isidentifier():
logger.warning(f"_is_vector_store called with invalid names: db='{database_name}', table='{table_name}'")
return False
# Expects normalized (unquoted) names

# SQL query to verify vector store criteria
sql_query = """
Expand Down Expand Up @@ -254,9 +352,8 @@ async def list_databases(self) -> List[str]:
async def list_tables(self, database_name: str) -> List[str]:
"""Lists all tables within the specified database."""
logger.info(f"TOOL START: list_tables called. database_name={database_name}")
if not database_name or not database_name.isidentifier():
logger.warning(f"TOOL WARNING: list_tables called with invalid database_name: {database_name}")
raise ValueError(f"Invalid database name provided: {database_name}")
database_name = self._normalize_identifier(database_name, "list_tables")

sql = "SHOW TABLES"
try:
results = await self._execute_query(sql, database=database_name)
Expand All @@ -273,18 +370,15 @@ async def get_table_schema(self, database_name: str, table_name: str) -> Dict[st
for a specific table in a database.
"""
logger.info(f"TOOL START: get_table_schema called. database_name={database_name}, table_name={table_name}")
if not database_name or not database_name.isidentifier():
logger.warning(f"TOOL WARNING: get_table_schema called with invalid database_name: {database_name}")
raise ValueError(f"Invalid database name provided: {database_name}")
if not table_name or not table_name.isidentifier():
logger.warning(f"TOOL WARNING: get_table_schema called with invalid table_name: {table_name}")
raise ValueError(f"Invalid table name provided: {table_name}")

sql = f"DESCRIBE `{database_name}`.`{table_name}`"
database_name = self._normalize_identifier(database_name, "get_table_schema")
table_name = self._normalize_identifier(table_name, "get_table_schema")

sql = f"DESCRIBE {self._quote_identifier(database_name)}.{self._quote_identifier(table_name)}"
try:
schema_results = await self._execute_query(sql)
schema_info = {}
if not schema_results:
# Use normalized names for information_schema query
exists_sql = "SELECT COUNT(*) as count FROM information_schema.tables WHERE table_schema = %s AND table_name = %s"
exists_result = await self._execute_query(exists_sql, params=(database_name, table_name))
if not exists_result or exists_result[0]['count'] == 0:
Expand Down Expand Up @@ -318,12 +412,8 @@ async def get_table_schema_with_relations(self, database_name: str, table_name:
Includes all basic schema info plus foreign key relationships and referenced tables.
"""
logger.info(f"TOOL START: get_table_schema_with_relations called. database_name={database_name}, table_name={table_name}")
if not database_name or not database_name.isidentifier():
logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid database_name: {database_name}")
raise ValueError(f"Invalid database name provided: {database_name}")
if not table_name or not table_name.isidentifier():
logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid table_name: {table_name}")
raise ValueError(f"Invalid table name provided: {table_name}")
database_name = self._normalize_identifier(database_name, "get_table_schema_with_relations")
table_name = self._normalize_identifier(table_name, "get_table_schema_with_relations")

try:
# 1. Get basic schema information
Expand All @@ -348,6 +438,7 @@ async def get_table_schema_with_relations(self, database_name: str, table_name:
ORDER BY kcu.CONSTRAINT_NAME, kcu.ORDINAL_POSITION
"""

# Use normalized names for information_schema query
fk_results = await self._execute_query(fk_sql, params=(database_name, table_name))

# 3. Add foreign key information to the basic schema
Expand Down Expand Up @@ -389,9 +480,7 @@ async def execute_sql(self, sql_query: str, database_name: str, parameters: Opti
Example `parameters`: ["value1", 123] corresponding to %s placeholders in `sql_query`.
"""
logger.info(f"TOOL START: execute_sql called. database_name={database_name}, sql_query={sql_query[:100]}, parameters={parameters}")
if database_name and not database_name.isidentifier():
logger.warning(f"TOOL WARNING: execute_sql called with invalid database_name: {database_name}")
raise ValueError(f"Invalid database name provided: {database_name}")
database_name = self._normalize_identifier(database_name, "execute_sql")
param_tuple = tuple(parameters) if parameters is not None else None
try:
results = await self._execute_query(sql_query, params=param_tuple, database=database_name)
Expand All @@ -406,17 +495,15 @@ async def create_database(self, database_name: str) -> Dict[str, Any]:
Creates a new database if it doesn't exist.
"""
logger.info(f"TOOL START: create_database called for database: '{database_name}'")
if not database_name or not database_name.isidentifier():
logger.error(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.")
database_name = self._normalize_identifier(database_name, "create_database")

# Check existence first to provide a clear message, though CREATE DATABASE IF NOT EXISTS is idempotent
if await self._database_exists(database_name):
message = f"Database '{database_name}' already exists."
logger.info(f"TOOL END: create_database. {message}")
return {"status": "exists", "message": message, "database_name": database_name}

sql = f"CREATE DATABASE IF NOT EXISTS `{database_name}`;"
sql = f"CREATE DATABASE IF NOT EXISTS {self._quote_identifier(database_name)};"

try:
await self._execute_query(sql, database=None)
Expand Down Expand Up @@ -455,12 +542,8 @@ async def create_vector_store_tool(self,
logger.info(f"TOOL START: create_vector_store called. DB: '{database_name}', Store: '{vector_store_name}', Model: '{model_name}', Embedding_Length: {embedding_length}, Distance_Requested: '{distance_function}'")

# --- Input Validation ---
if not database_name or not database_name.isidentifier():
logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")
if not vector_store_name or not vector_store_name.isidentifier():
logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.")
database_name = self._normalize_identifier(database_name, "create_vector_store_tool")
vector_store_name = self._normalize_identifier(vector_store_name, "create_vector_store_tool")

if not isinstance(embedding_length, int) or embedding_length <= 0:
logger.error(f"Invalid embedding_length: {embedding_length}. Must be a positive integer.")
Expand Down Expand Up @@ -504,7 +587,7 @@ async def create_vector_store_tool(self,

# --- SQL Query for Vector Store Table Creation ---
schema_query = f"""
CREATE TABLE IF NOT EXISTS `{vector_store_name}` (
CREATE TABLE IF NOT EXISTS {self._quote_identifier(vector_store_name)} (
id VARCHAR(36) NOT NULL DEFAULT UUID_v7() PRIMARY KEY,
document TEXT NOT NULL,
embedding VECTOR({embedding_length}) NOT NULL,
Expand Down Expand Up @@ -550,9 +633,7 @@ async def list_vector_stores(self, database_name: str) -> List[str]:
logger.info(f"TOOL START: list_vector_stores called for database: '{database_name}'")

# --- Input Validation ---
if not database_name or not database_name.isidentifier():
logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")
database_name = self._normalize_identifier(database_name, "list_vector_stores")

if not await self._database_exists(database_name):
logger.warning(f"Database '{database_name}' does not exist. Cannot list vector stores.")
Expand All @@ -577,6 +658,7 @@ async def list_vector_stores(self, database_name: str) -> List[str]:
"""

try:
# Use normalized name for information_schema query
results = await self._execute_query(sql_query, params=(database_name,), database='information_schema')

store_list = [row['TABLE_NAME'] for row in results if 'TABLE_NAME' in row]
Expand Down Expand Up @@ -614,12 +696,8 @@ async def delete_vector_store(self,
logger.info(f"TOOL START: delete_vector_store called for: '{database_name}.{vector_store_name}'")

# --- Input Validation for names ---
if not database_name or not database_name.isidentifier():
logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")
if not vector_store_name or not vector_store_name.isidentifier():
logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.")
database_name = self._normalize_identifier(database_name, "delete_vector_store")
vector_store_name = self._normalize_identifier(vector_store_name, "delete_vector_store")

# --- Database Existence Check ---
if not await self._database_exists(database_name):
Expand All @@ -640,7 +718,7 @@ async def delete_vector_store(self,
return {"status": "not_vector_store", "message": message}

# --- SQL Query for Deletion ---
drop_query = f"DROP TABLE IF EXISTS `{vector_store_name}`;"
drop_query = f"DROP TABLE IF EXISTS {self._quote_identifier(vector_store_name)};"

try:
await self._execute_query(drop_query, database=database_name)
Expand Down Expand Up @@ -670,12 +748,8 @@ async def insert_docs_vector_store(self, database_name: str, vector_store_name:
If metadata is not provided, an empty dict will be used for each document.
"""
import json
if not database_name or not database_name.isidentifier():
logger.error(f"Invalid database_name: '{database_name}'")
raise ValueError(f"Invalid database_name: '{database_name}'")
if not vector_store_name or not vector_store_name.isidentifier():
logger.error(f"Invalid vector_store_name: '{vector_store_name}'")
raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'")
database_name = self._normalize_identifier(database_name, "insert_docs_vector_store")
vector_store_name = self._normalize_identifier(vector_store_name, "insert_docs_vector_store")
if not isinstance(documents, list) or not documents or not all(isinstance(doc, str) and doc for doc in documents):
logger.error("'documents' must be a non-empty list of non-empty strings.")
raise ValueError("'documents' must be a non-empty list of non-empty strings.")
Expand All @@ -690,7 +764,7 @@ async def insert_docs_vector_store(self, database_name: str, vector_store_name:
# Prepare metadata JSON
metadata_json = [json.dumps(m) for m in metadata]
# Prepare values for batch insert
insert_query = f"INSERT INTO `{database_name}`.`{vector_store_name}` (document, embedding, metadata) VALUES (%s, VEC_FromText(%s), %s)"
insert_query = f"INSERT INTO {self._quote_identifier(database_name)}.{self._quote_identifier(vector_store_name)} (document, embedding, metadata) VALUES (%s, VEC_FromText(%s), %s)"
inserted = 0
errors = []
for doc, emb, meta in zip(documents, embeddings, metadata_json):
Expand Down Expand Up @@ -723,12 +797,8 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_
if not user_query or not isinstance(user_query, str):
logger.error("user_query must be a non-empty string.")
raise ValueError("user_query must be a non-empty string.")
if not database_name or not database_name.isidentifier():
logger.error(f"Invalid database_name: '{database_name}'")
raise ValueError(f"Invalid database_name: '{database_name}'")
if not vector_store_name or not vector_store_name.isidentifier():
logger.error(f"Invalid vector_store_name: '{vector_store_name}'")
raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'")
database_name = self._normalize_identifier(database_name, "search_vector_store")
vector_store_name = self._normalize_identifier(vector_store_name, "search_vector_store")
if not isinstance(k, int) or k <= 0:
logger.error("k must be a positive integer.")
raise ValueError("k must be a positive integer.")
Expand All @@ -741,7 +811,7 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_
document,
metadata,
VEC_DISTANCE_COSINE(embedding, VEC_FromText(%s)) AS distance
FROM `{database_name}`.`{vector_store_name}`
FROM {self._quote_identifier(database_name)}.{self._quote_identifier(vector_store_name)}
ORDER BY distance ASC
LIMIT %s
"""
Expand Down