diff --git a/.env.example b/.env.example index 5981dc2..f14a068 100644 --- a/.env.example +++ b/.env.example @@ -10,4 +10,6 @@ REDIS_SSL_CERTFILE=/path/to/cert.pem REDIS_CERT_REQS=required REDIS_CA_CERTS=/path/to/ca_certs.pem REDIS_CLUSTER_MODE=False +ALLOW_DB_SWITCH=false +BLOCKED_DBS=1,2,15 MCP_TRANSPORT=stdio \ No newline at end of file diff --git a/smithery.yaml b/smithery.yaml index 48c5bdf..1560670 100644 --- a/smithery.yaml +++ b/smithery.yaml @@ -47,6 +47,14 @@ startCommand: type: string default: "" description: Path to trusted CA certificates file + allowDbSwitch: + type: boolean + default: false + description: Allow database switching + blockedDbs: + type: string + default: "" + description: Comma-separated list of database numbers to block access to (e.g. "1,2,15") commandFunction: # A JS function that produces the CLI command based on the given config to start the MCP on stdio. |- @@ -63,7 +71,9 @@ startCommand: REDIS_SSL_KEYFILE: config.redisSSLKeyfile, REDIS_SSL_CERTFILE: config.redisSSLCertfile, REDIS_CERT_REQS: config.redisCertReqs, - REDIS_CA_CERTS: config.redisCACerts + REDIS_CA_CERTS: config.redisCACerts, + ALLOW_DB_SWITCH: config.allowDbSwitch ? 'true' : 'false', + BLOCKED_DBS: config.blockedDbs } }) exampleConfig: @@ -77,3 +87,5 @@ startCommand: redisSSLCertfile: "" redisCertReqs: required redisCACerts: "" + allowDbSwitch: false + blockedDbs: "" diff --git a/src/common/config.py b/src/common/config.py index 0ef7a3f..6c20989 100644 --- a/src/common/config.py +++ b/src/common/config.py @@ -10,14 +10,16 @@ "port": int(os.getenv('REDIS_PORT',6379)), "username": os.getenv('REDIS_USERNAME', None), "password": os.getenv('REDIS_PWD',''), - "ssl": os.getenv('REDIS_SSL', False) in ('true', '1', 't'), + "ssl": os.getenv('REDIS_SSL', 'false').lower() in ('true', '1', 't'), "ssl_ca_path": os.getenv('REDIS_SSL_CA_PATH', None), "ssl_keyfile": os.getenv('REDIS_SSL_KEYFILE', None), "ssl_certfile": os.getenv('REDIS_SSL_CERTFILE', None), "ssl_cert_reqs": os.getenv('REDIS_SSL_CERT_REQS', 'required'), "ssl_ca_certs": os.getenv('REDIS_SSL_CA_CERTS', None), - "cluster_mode": os.getenv('REDIS_CLUSTER_MODE', False) in ('true', '1', 't'), - "db": int(os.getenv('REDIS_DB', 0))} + "cluster_mode": os.getenv('REDIS_CLUSTER_MODE', 'false').lower() in ('true', '1', 't'), + "db": int(os.getenv('REDIS_DB', 0)), + "allow_db_switch": os.getenv('ALLOW_DB_SWITCH', 'false').lower() in ('true', '1', 't'), + "blocked_dbs": [int(db.strip()) for db in os.getenv('BLOCKED_DBS', '').split(',') if db.strip().isdigit()]} def generate_redis_uri(): @@ -58,4 +60,27 @@ def generate_redis_uri(): if query_params: base_uri += "?" + urllib.parse.urlencode(query_params) - return base_uri \ No newline at end of file + return base_uri + + +def is_database_blocked(db: int) -> bool: + """ + Check if a database number is in the blocked list. + + Args: + db (int): Database number to check + + Returns: + bool: True if database is blocked, False otherwise + """ + return db in REDIS_CFG.get("blocked_dbs", []) + + +def get_blocked_databases() -> list: + """ + Get the list of blocked database numbers. + + Returns: + list: List of blocked database numbers + """ + return REDIS_CFG.get("blocked_dbs", []) \ No newline at end of file diff --git a/src/common/connection.py b/src/common/connection.py index 298966d..bb14962 100644 --- a/src/common/connection.py +++ b/src/common/connection.py @@ -3,77 +3,125 @@ import redis from redis import Redis from redis.cluster import RedisCluster -from typing import Optional, Type, Union -from common.config import REDIS_CFG - -from common.config import generate_redis_uri +from redis.exceptions import RedisError +from typing import Optional, Type, Union, Dict, Any +from common.config import REDIS_CFG, is_database_blocked class RedisConnectionManager: _instance: Optional[Redis] = None + _DEFAULT_MAX_CONNECTIONS = 10 @classmethod - def get_connection(cls, decode_responses=True) -> Redis: - if cls._instance is None: - try: - if REDIS_CFG["cluster_mode"]: - redis_class: Type[Union[Redis, RedisCluster]] = redis.cluster.RedisCluster - connection_params = { - "host": REDIS_CFG["host"], - "port": REDIS_CFG["port"], - "username": REDIS_CFG["username"], - "password": REDIS_CFG["password"], - "ssl": REDIS_CFG["ssl"], - "ssl_ca_path": REDIS_CFG["ssl_ca_path"], - "ssl_keyfile": REDIS_CFG["ssl_keyfile"], - "ssl_certfile": REDIS_CFG["ssl_certfile"], - "ssl_cert_reqs": REDIS_CFG["ssl_cert_reqs"], - "ssl_ca_certs": REDIS_CFG["ssl_ca_certs"], - "decode_responses": decode_responses, - "lib_name": f"redis-py(mcp-server_v{__version__})", - "max_connections_per_node": 10 - } - else: - redis_class: Type[Union[Redis, RedisCluster]] = redis.Redis - connection_params = { - "host": REDIS_CFG["host"], - "port": REDIS_CFG["port"], - "db": REDIS_CFG["db"], - "username": REDIS_CFG["username"], - "password": REDIS_CFG["password"], - "ssl": REDIS_CFG["ssl"], - "ssl_ca_path": REDIS_CFG["ssl_ca_path"], - "ssl_keyfile": REDIS_CFG["ssl_keyfile"], - "ssl_certfile": REDIS_CFG["ssl_certfile"], - "ssl_cert_reqs": REDIS_CFG["ssl_cert_reqs"], - "ssl_ca_certs": REDIS_CFG["ssl_ca_certs"], - "decode_responses": decode_responses, - "lib_name": f"redis-py(mcp-server_v{__version__})", - "max_connections": 10 - } - - cls._instance = redis_class(**connection_params) + def _build_connection_params(cls, decode_responses: bool = True, db: Optional[int] = None) -> Dict[str, Any]: + """Build connection parameters from configuration.""" + params = { + "host": REDIS_CFG["host"], + "port": REDIS_CFG["port"], + "username": REDIS_CFG["username"], + "password": REDIS_CFG["password"], + "ssl": REDIS_CFG["ssl"], + "ssl_ca_path": REDIS_CFG["ssl_ca_path"], + "ssl_keyfile": REDIS_CFG["ssl_keyfile"], + "ssl_certfile": REDIS_CFG["ssl_certfile"], + "ssl_cert_reqs": REDIS_CFG["ssl_cert_reqs"], + "ssl_ca_certs": REDIS_CFG["ssl_ca_certs"], + "decode_responses": decode_responses, + "lib_name": f"redis-py(mcp-server_v{__version__})", + } + + # Handle database parameter + if REDIS_CFG["cluster_mode"]: + if db is not None: + raise RedisError("Database switching not supported in cluster mode") + params["max_connections_per_node"] = cls._DEFAULT_MAX_CONNECTIONS + else: + params["db"] = db if db is not None else REDIS_CFG["db"] + params["max_connections"] = cls._DEFAULT_MAX_CONNECTIONS + + return params - except redis.exceptions.ConnectionError: - print("Failed to connect to Redis server", file=sys.stderr) - raise - except redis.exceptions.AuthenticationError: - print("Authentication failed", file=sys.stderr) - raise - except redis.exceptions.TimeoutError: - print("Connection timed out", file=sys.stderr) - raise - except redis.exceptions.ResponseError as e: - print(f"Response error: {e}", file=sys.stderr) - raise - except redis.exceptions.RedisError as e: - print(f"Redis error: {e}", file=sys.stderr) - raise - except redis.exceptions.ClusterError as e: - print(f"Redis Cluster error: {e}", file=sys.stderr) - raise - except Exception as e: + @classmethod + def _create_connection(cls, decode_responses: bool = True, db: Optional[int] = None) -> Redis: + """Create a new Redis connection with the given parameters.""" + try: + connection_params = cls._build_connection_params(decode_responses, db) + + if REDIS_CFG["cluster_mode"]: + redis_class: Type[Union[Redis, RedisCluster]] = redis.cluster.RedisCluster + else: + redis_class: Type[Union[Redis, RedisCluster]] = redis.Redis + + return redis_class(**connection_params) + + except redis.exceptions.ConnectionError: + print("Failed to connect to Redis server", file=sys.stderr) + raise + except redis.exceptions.AuthenticationError: + print("Authentication failed", file=sys.stderr) + raise + except redis.exceptions.TimeoutError: + print("Connection timed out", file=sys.stderr) + raise + except redis.exceptions.ResponseError as e: + print(f"Response error: {e}", file=sys.stderr) + raise + except redis.exceptions.RedisError as e: + print(f"Redis error: {e}", file=sys.stderr) + raise + except redis.exceptions.ClusterError as e: + print(f"Redis Cluster error: {e}", file=sys.stderr) + raise + except Exception as e: + if db is not None: + raise RedisError(f"Error connecting to database {db}: {str(e)}") + else: print(f"Unexpected error: {e}", file=sys.stderr) raise - return cls._instance + @classmethod + def get_connection(cls, decode_responses: bool = True, db: Optional[int] = None, use_singleton: bool = True) -> Redis: + """ + Get a Redis connection. + + Args: + decode_responses (bool): Whether to decode responses + db (Optional[int]): Database number to connect to (None uses config default) + use_singleton (bool): Whether to use singleton pattern (True) or create new connection (False) + + Returns: + Redis: Redis connection instance + + Raises: + RedisError: If cluster mode is enabled and db is specified, or connection fails + """ + # Check if the specified database is blocked + if db is not None and is_database_blocked(db): + raise RedisError(f"Access to database {db} is blocked") + + if use_singleton and db is None: + # Singleton behavior for default database + if cls._instance is None: + cls._instance = cls._create_connection(decode_responses) + return cls._instance + else: + # Create new connection for specific database or when singleton is disabled + return cls._create_connection(decode_responses, db) + + @classmethod + def get_connection_for_db(cls, db: int, decode_responses: bool = True) -> Redis: + """ + Get a Redis connection for a specific database. + This creates a new connection rather than using the singleton. + + Args: + db (int): Database number to connect to + decode_responses (bool): Whether to decode responses + + Returns: + Redis: Redis connection instance for the specified database + + Raises: + RedisError: If cluster mode is enabled or connection fails + """ + return cls.get_connection(decode_responses=decode_responses, db=db, use_singleton=False) diff --git a/src/tools/server_management.py b/src/tools/server_management.py index 2f0d254..5f229fd 100644 --- a/src/tools/server_management.py +++ b/src/tools/server_management.py @@ -1,6 +1,7 @@ from common.connection import RedisConnectionManager from redis.exceptions import RedisError from common.server import mcp +from common.config import REDIS_CFG, is_database_blocked, get_blocked_databases @mcp.tool() async def dbsize() -> int: @@ -39,4 +40,116 @@ async def client_list() -> list: clients = r.client_list() return clients except RedisError as e: - return f"Error retrieving client list: {str(e)}" \ No newline at end of file + return f"Error retrieving client list: {str(e)}" + + +@mcp.tool() +async def switch_database(db: int) -> str: + """ + Switch to a different Redis database. + + Args: + db (int): Database number (0-15 for most Redis configurations) + + Returns: + str: Confirmation message or error message + + Note: + Database switching must be enabled via ALLOW_DB_SWITCH environment variable. + Set ALLOW_DB_SWITCH=true to enable this feature. + """ + # Check if database switching is allowed + if not REDIS_CFG["allow_db_switch"]: + return "Error: Database switching is disabled. Set ALLOW_DB_SWITCH=true to enable this feature." + + # Check if the target database is blocked + if is_database_blocked(db): + blocked_dbs = get_blocked_databases() + return f"Error: Access to database {db} is blocked. Blocked databases: {blocked_dbs}" + + try: + r = RedisConnectionManager.get_connection() + r.execute_command("SELECT", db) + return f"Successfully switched to database {db}" + except RedisError as e: + return f"Error switching to database {db}: {str(e)}" + + +@mcp.tool() +async def get_current_database() -> dict: + """ + Get information about the currently selected database. + + Returns: + Dict containing current database info and key count + """ + try: + r = RedisConnectionManager.get_connection() + # Get current database info + info = r.info('keyspace') + + # Try to determine current database by checking connection + try: + # This will show us which DB we're connected to + config_info = r.config_get('databases') + total_dbs = int(config_info.get('databases', 16)) + except: + total_dbs = 16 # Default Redis database count + + # Get current database number (Redis doesn't have a direct command for this) + # We'll use a workaround by checking which DB has activity or using SELECT + current_info = r.execute_command("INFO", "keyspace") + + return { + "keyspace_info": info, + "total_databases": total_dbs, + "current_keyspace": current_info, + "note": "Use switch_database(db) to change database" + } + except RedisError as e: + return {"error": f"Error getting database info: {str(e)}"} + + +@mcp.tool() +async def list_all_databases() -> dict: + """ + List all databases and their key counts. + + Returns: + Dict containing information about all databases + """ + try: + databases_info = {} + + # Try databases 0-15 (standard Redis range) + for db_num in range(16): + try: + # Create connection for specific database + r = RedisConnectionManager.get_connection_for_db(db_num) + key_count = r.dbsize() + + if key_count > 0: # Only include databases with keys + databases_info[f"db{db_num}"] = { + "database": db_num, + "keys": key_count, + "has_data": True + } + else: + databases_info[f"db{db_num}"] = { + "database": db_num, + "keys": 0, + "has_data": False + } + + except RedisError: + # Skip databases that can't be accessed + continue + + return { + "databases": databases_info, + "total_found": len([db for db in databases_info.values() if db["has_data"]]), + "note": "Use switch_database(db_number) to switch between databases" + } + + except Exception as e: + return {"error": f"Error listing databases: {str(e)}"} \ No newline at end of file