diff --git a/src/SharpLinks.cs b/src/SharpLinks.cs index baa1321..7600012 100644 --- a/src/SharpLinks.cs +++ b/src/SharpLinks.cs @@ -19,6 +19,7 @@ using System.DirectoryServices.ActiveDirectory; using System.IO; using System.Linq; +using System.Reflection; using System.Security.Principal; using System.Threading.Tasks; using Microsoft.Extensions.Logging; @@ -29,8 +30,7 @@ using SharpHoundCommonLib.Processors; using Timer = System.Timers.Timer; -namespace Sharphound -{ +namespace Sharphound { internal class SharpLinks : Links { /// /// Define methods that SharpHound executes as part of operation pipeline. @@ -62,7 +62,8 @@ public IContext Initialize(IContext context, LdapConfig options) { if (!context.LDAPUtils.GetDomain(out var d)) { context.Logger.LogCritical("unable to get current domain"); context.Flags.IsFaulted = true; - } else { + } + else { context.DomainName = d.Name; context.Logger.LogInformation("Resolved current domain to {Domain}", d.Name); } @@ -90,7 +91,8 @@ public IContext Initialize(IContext context, LdapConfig options) { } File.Delete(filename); - } catch (Exception e) { + } + catch (Exception e) { context.Logger.LogCritical(e, "unable to write to target directory"); context.Flags.IsFaulted = true; } @@ -137,26 +139,53 @@ public IContext InitCommonLib(IContext context) { context.Logger.LogTrace("Getting cache path"); var path = context.GetCachePath(); context.Logger.LogTrace("Cache Path: {Path}", path); + + var version = Assembly.GetExecutingAssembly().GetName().Version; Cache cache; if (!File.Exists(path)) { context.Logger.LogTrace("Cache file does not exist"); - cache = null; - } else + cache = Cache.CreateNewCache(version); + } + else if (context.Flags.InvalidateCache) { + context.Logger.LogTrace($"Skipping cache load per option {nameof(Options.RebuildCache)}"); + cache = Cache.CreateNewCache(version); + } + else { try { context.Logger.LogTrace("Loading cache from disk"); var json = File.ReadAllText(path); cache = JsonConvert.DeserializeObject(json, CacheContractResolver.Settings); context.Logger.LogInformation("Loaded cache with stats: {stats}", cache?.GetCacheStats()); - } catch (Exception e) { + } + catch (Exception e) { context.Logger.LogError("Error loading cache: {exception}, creating new", e); - cache = null; + cache = Cache.CreateNewCache(version); + } + + if (CacheNeedsInvalidation(cache, version)) { + context.Logger.LogInformation("Old cache found, ignoring"); + cache = Cache.CreateNewCache(version); } + } CommonLib.InitializeCommonLib(context.Logger, cache); context.Logger.LogTrace("Exiting InitCommonLib"); return context; } + private bool CacheNeedsInvalidation(Cache cache, Version version) { + var threshold = DateTime.Now.Subtract(TimeSpan.FromDays(30)); + if (cache.CacheCreationDate < threshold) { + return true; + } + + if (cache.CacheCreationVersion == null || version > cache.CacheCreationVersion) { + return true; + } + + return false; + } + public async Task GetDomainsForEnumeration(IContext context) { context.Logger.LogTrace("Entering GetDomainsForEnumeration"); if (context.Flags.RecurseDomains) { @@ -178,7 +207,8 @@ public async Task GetDomainsForEnumeration(IContext context) { Forest forest; try { forest = dObj.Forest; - } catch (Exception e) { + } + catch (Exception e) { context.Logger.LogError("Unable to get forest object for SearchForest: {Message}", e.Message); context.Flags.IsFaulted = true; return context; @@ -224,7 +254,8 @@ public async Task GetDomainsForEnumeration(IContext context) { DomainSid = sid } }; - } else { + } + else { context.Domains = new[] { new EnumerationDomain { Name = domain, @@ -315,11 +346,11 @@ public IContext Finish(IContext context) { } public IContext SaveCacheFile(IContext context) { - if (context.Flags.MemCache) - return context; - // 15. Program exit started. Save the cache file + // if (context.Flags.MemCache) + // return context; + // // 15. Program exit started. Save the cache file var cache = Cache.GetCacheInstance(); - context.Logger.LogInformation("Saving cache with stats: {stats}", cache.GetCacheStats()); + // context.Logger.LogInformation("Saving cache with stats: {stats}", cache.GetCacheStats()); var serialized = JsonConvert.SerializeObject(cache, CacheContractResolver.Settings); using var stream = new StreamWriter(context.GetCachePath());