diff --git a/src/Shared/ProviderConfiguration.cs b/src/Shared/ProviderConfiguration.cs index 06cb0b16..9c18d1e9 100644 --- a/src/Shared/ProviderConfiguration.cs +++ b/src/Shared/ProviderConfiguration.cs @@ -36,11 +36,11 @@ internal ProviderConfiguration() internal static ProviderConfiguration ProviderConfigurationForSessionState(NameValueCollection config) { ProviderConfiguration configuration = new ProviderConfiguration(config); - + configuration.ThrowOnError = GetBoolSettings(config, "throwOnError", true); int retryTimeoutInMilliSec = GetIntSettings(config, "retryTimeoutInMilliseconds", 5000); configuration.RetryTimeout = new TimeSpan(0, 0, 0, 0, retryTimeoutInMilliSec); - + // Get request timeout from config HttpRuntimeSection httpRuntimeSection = ConfigurationManager.GetSection("system.web/httpRuntime") as HttpRuntimeSection; configuration.RequestTimeout = httpRuntimeSection.ExecutionTimeout; @@ -57,10 +57,10 @@ internal static ProviderConfiguration ProviderConfigurationForSessionState(NameV internal static ProviderConfiguration ProviderConfigurationForOutputCache(NameValueCollection config) { ProviderConfiguration configuration = new ProviderConfiguration(config); - + // No retry login for output cache provider configuration.RetryTimeout = TimeSpan.Zero; - + // Session state specific attribute which are not applicable to output cache configuration.ThrowOnError = true; configuration.RequestTimeout = TimeSpan.Zero; @@ -81,7 +81,7 @@ private ProviderConfiguration(NameValueCollection config) Port = GetIntSettings(config, "port", 0); AccessKey = GetStringSettings(config, "accessKey", null); UseSsl = GetBoolSettings(config, "ssl", true); - + // All below parameters are only fetched from web.config DatabaseId = GetIntSettings(config, "databaseId", 0); ApplicationName = GetStringSettings(config, "applicationName", null); @@ -129,6 +129,12 @@ private static string GetStringSettings(NameValueCollection config, string attrN return defaultVal; } + string connectionStringValue = GetFromConnectionString(literalValue); + if (!string.IsNullOrEmpty(connectionStringValue)) + { + return connectionStringValue; + } + string appSettingsValue = GetFromAppSetting(literalValue); if (!string.IsNullOrEmpty(appSettingsValue)) { @@ -205,6 +211,20 @@ private static string GetFromAppSetting(string attrName) return null; } + private static string GetFromConnectionString(string connectionStringName) + { + if (!string.IsNullOrEmpty(connectionStringName)) + { + var connectionString = ConfigurationManager.ConnectionStrings[connectionStringName]; + + if (connectionString != null) + { + return connectionString.ConnectionString; + } + } + return null; + } + // Reads string value from web.config session state section private static string GetFromConfig(NameValueCollection config, string attrName) { @@ -220,12 +240,12 @@ internal static void EnableLoggingIfParametersAvailable(NameValueCollection conf { string LoggingClassName = GetStringSettings(config, "loggingClassName", null); string LoggingMethodName = GetStringSettings(config, "loggingMethodName", null); - + if( !string.IsNullOrEmpty(LoggingClassName) && !string.IsNullOrEmpty(LoggingMethodName) ) { // Find 'Type' that is same as fully qualified class name if not found than also don't throw error and ignore case while searching Type LoggingClass = Type.GetType(LoggingClassName, throwOnError: false, ignoreCase: true); - + if (LoggingClass == null) { // If class name is not assembly qualified name than look for class in all assemblies one by one @@ -264,7 +284,7 @@ internal static Type GetLoggingClass(string LoggingClassName) if (LoggingClass == null) { // If class name is not assembly qualified name and it also doesn't contain namespace (it is just class name) than - // try to use assembly name as namespace and try to load class from all assemblies one by one + // try to use assembly name as namespace and try to load class from all assemblies one by one LoggingClass = a.GetType(a.GetName().Name + "." + LoggingClassName, throwOnError: false, ignoreCase: true); } if (LoggingClass != null) diff --git a/src/Shared/StackExchangeClientConnection.cs b/src/Shared/StackExchangeClientConnection.cs index 87ba215f..69134e22 100644 --- a/src/Shared/StackExchangeClientConnection.cs +++ b/src/Shared/StackExchangeClientConnection.cs @@ -5,6 +5,7 @@ using System; using System.Diagnostics; +using System.Net; using System.Web.SessionState; using StackExchange.Redis; @@ -13,19 +14,24 @@ namespace Microsoft.Web.Redis internal class StackExchangeClientConnection : IRedisClientConnection { - ConnectionMultiplexer redisMultiplexer; - IDatabase connection; - ProviderConfiguration configuration; + ConnectionMultiplexer _redisMultiplexer; + IDatabase _connection; + ProviderConfiguration _configuration; public StackExchangeClientConnection(ProviderConfiguration configuration) { - this.configuration = configuration; + _configuration = configuration; ConfigurationOptions configOption; // If connection string is given then use it otherwise use individual options if (!string.IsNullOrEmpty(configuration.ConnectionString)) { configOption = ConfigurationOptions.Parse(configuration.ConnectionString); + + if (!string.IsNullOrEmpty(configOption.ServiceName)) + { + ModifyEndpointsForSentinelConfiguration(configOption); + } } else { @@ -52,20 +58,43 @@ public StackExchangeClientConnection(ProviderConfiguration configuration) configOption.SyncTimeout = configuration.OperationTimeoutInMilliSec; } } - if (LogUtility.logger == null) + + _redisMultiplexer = LogUtility.logger == null ? ConnectionMultiplexer.Connect(configOption) : ConnectionMultiplexer.Connect(configOption, LogUtility.logger); + + _connection = _redisMultiplexer.GetDatabase(configuration.DatabaseId); + } + + private static void ModifyEndpointsForSentinelConfiguration(ConfigurationOptions configOption) + { + var sentinelConfiguration = new ConfigurationOptions { - redisMultiplexer = ConnectionMultiplexer.Connect(configOption); - } - else + CommandMap = CommandMap.Sentinel, + TieBreaker = "", + ServiceName = configOption.ServiceName, + SyncTimeout = configOption.SyncTimeout + }; + + EndPoint masterEndPoint = null; + + foreach (var endpoint in configOption.EndPoints) { - redisMultiplexer = ConnectionMultiplexer.Connect(configOption, LogUtility.logger); + sentinelConfiguration.EndPoints.Add(endpoint); + var sentinelConnection = ConnectionMultiplexer.Connect(sentinelConfiguration); + masterEndPoint = sentinelConnection.GetServer(endpoint).SentinelGetMasterAddressByName(sentinelConfiguration.ServiceName); + + if (masterEndPoint != null) + { + break; + } } - this.connection = redisMultiplexer.GetDatabase(configuration.DatabaseId); + + configOption.EndPoints.Clear(); + configOption.EndPoints.Add(masterEndPoint); } public IDatabase RealConnection { - get { return connection; } + get { return _connection; } } public void Open() @@ -73,21 +102,21 @@ public void Open() public void Close() { - redisMultiplexer.Close(); + _redisMultiplexer.Close(); } public bool Expiry(string key, int timeInSeconds) { TimeSpan timeSpan = new TimeSpan(0, 0, timeInSeconds); RedisKey redisKey = key; - return (bool)RetryLogic(() => connection.KeyExpire(redisKey,timeSpan)); + return (bool)RetryLogic(() => _connection.KeyExpire(redisKey,timeSpan)); } public object Eval(string script, string[] keyArgs, object[] valueArgs) { RedisKey[] redisKeyArgs = new RedisKey[keyArgs.Length]; RedisValue[] redisValueArgs = new RedisValue[valueArgs.Length]; - + int i = 0; foreach (string key in keyArgs) { @@ -110,7 +139,7 @@ public object Eval(string script, string[] keyArgs, object[] valueArgs) } i++; } - return RetryLogic(() => connection.ScriptEvaluate(script, redisKeyArgs, redisValueArgs)); + return RetryLogic(() => _connection.ScriptEvaluate(script, redisKeyArgs, redisValueArgs)); } private object RetryForScriptNotFound(Func redisOperation) @@ -146,18 +175,16 @@ private object RetryLogic(Func redisOperation) catch (Exception) { TimeSpan passedTime = DateTime.Now - startTime; - if (configuration.RetryTimeout < passedTime) + if (_configuration.RetryTimeout < passedTime) { throw; } - else + + var remainingTimeout = (int)(_configuration.RetryTimeout.TotalMilliseconds - passedTime.TotalMilliseconds); + // if remaining time is less than 1 sec than wait only for that much time and than give a last try + if (remainingTimeout < timeToSleepBeforeRetryInMiliseconds) { - int remainingTimeout = (int)(configuration.RetryTimeout.TotalMilliseconds - passedTime.TotalMilliseconds); - // if remaining time is less than 1 sec than wait only for that much time and than give a last try - if (remainingTimeout < timeToSleepBeforeRetryInMiliseconds) - { - timeToSleepBeforeRetryInMiliseconds = remainingTimeout; - } + timeToSleepBeforeRetryInMiliseconds = remainingTimeout; } // First time try after 20 msec after that try after 1 second @@ -176,7 +203,7 @@ public int GetSessionTimeout(object rowDataFromRedis) int sessionTimeout = (int)lockScriptReturnValueArray[2]; if (sessionTimeout == -1) { - sessionTimeout = (int) configuration.SessionTimeout.TotalSeconds; + sessionTimeout = (int) _configuration.SessionTimeout.TotalSeconds; } // converting seconds to minutes sessionTimeout = sessionTimeout / 60; @@ -194,7 +221,7 @@ public bool IsLocked(object rowDataFromRedis) public string GetLockId(object rowDataFromRedis) { - return StackExchangeClientConnection.GetLockIdStatic(rowDataFromRedis); + return GetLockIdStatic(rowDataFromRedis); } internal static string GetLockIdStatic(object rowDataFromRedis) @@ -207,7 +234,7 @@ internal static string GetLockIdStatic(object rowDataFromRedis) public ISessionStateItemCollection GetSessionData(object rowDataFromRedis) { - return StackExchangeClientConnection.GetSessionDataStatic(rowDataFromRedis); + return GetSessionDataStatic(rowDataFromRedis); } internal static ISessionStateItemCollection GetSessionDataStatic(object rowDataFromRedis) @@ -220,7 +247,7 @@ internal static ISessionStateItemCollection GetSessionDataStatic(object rowDataF if (lockScriptReturnValueArray.Length > 1 && lockScriptReturnValueArray[1] != null) { RedisResult[] data = (RedisResult[])lockScriptReturnValueArray[1]; - + // LUA script returns data as object array so keys and values are store one after another // This list has to be even because it contains pair of as {key, value, key, value} if (data != null && data.Length != 0 && data.Length % 2 == 0) @@ -247,23 +274,23 @@ public void Set(string key, byte[] data, DateTime utcExpiry) RedisKey redisKey = key; RedisValue redisValue = data; TimeSpan timeSpanForExpiry = utcExpiry - DateTime.UtcNow; - connection.StringSet(redisKey, redisValue, timeSpanForExpiry); + _connection.StringSet(redisKey, redisValue, timeSpanForExpiry); } public byte[] Get(string key) { RedisKey redisKey = key; - RedisValue redisValue = connection.StringGet(redisKey); - return (byte[]) redisValue; + RedisValue redisValue = _connection.StringGet(redisKey); + return redisValue; } public void Remove(string key) { RedisKey redisKey = key; - connection.KeyDelete(redisKey); + _connection.KeyDelete(redisKey); } - public byte[] GetOutputCacheDataFromResult(object rowDataFromRedis) + public byte[] GetOutputCacheDataFromResult(object rowDataFromRedis) { RedisResult rowDataAsRedisResult = (RedisResult)rowDataFromRedis; return (byte[]) rowDataAsRedisResult;