Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 91 additions & 2 deletions kc/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log/slog"
"net/http"
"net/url"
"sync"
"time"

kiteconnect "github.com/zerodha/gokiteconnect/v4"
Expand Down Expand Up @@ -121,6 +122,10 @@ type Manager struct {
Instruments *instruments.Manager
sessionManager *SessionRegistry
sessionSigner *SessionSigner

// Default session storage for authenticated session fallback
defaultSessionID string
defaultSessionMu sync.RWMutex
}

// NewManager creates a new manager with default configuration
Expand Down Expand Up @@ -223,9 +228,29 @@ func (m *Manager) logSessionRetrievedData(sessionID string) {

// GetOrCreateSession retrieves an existing Kite session or creates a new one atomically
func (m *Manager) GetOrCreateSession(mcpSessionID string) (*KiteSessionData, bool, error) {
originalSessionID := mcpSessionID

if err := m.validateSessionID(mcpSessionID); err != nil {
m.Logger.Warn("GetOrCreateSession called with empty MCP session ID")
return nil, false, err
// If the provided session ID is invalid, try to use the default session
defaultSessionID := m.GetDefaultSession()
if defaultSessionID != "" {
m.Logger.Info("Invalid session ID provided, attempting to use default session",
"invalid_session_id", mcpSessionID,
"default_session_id", defaultSessionID)
mcpSessionID = defaultSessionID

// Validate the default session ID
if err := m.validateSessionID(mcpSessionID); err != nil {
m.Logger.Warn("Default session ID is also invalid, clearing it",
"default_session_id", defaultSessionID)
m.ClearDefaultSession()
return nil, false, err
}
} else {
m.Logger.Warn("GetOrCreateSession called with invalid session ID and no default session available",
"invalid_session_id", originalSessionID)
return nil, false, err
}
}

// Use atomic GetOrCreateSessionData to eliminate TOCTOU race condition
Expand All @@ -243,10 +268,41 @@ func (m *Manager) GetOrCreateSession(mcpSessionID string) (*KiteSessionData, boo
return nil, false, err
}

// If this is a new session (unauthenticated), try to use the default session instead
if isNew {
defaultSessionID := m.GetDefaultSession()
if defaultSessionID != "" && defaultSessionID != mcpSessionID {
m.Logger.Info("New session created, attempting to use default authenticated session",
"new_session_id", mcpSessionID,
"default_session_id", defaultSessionID)

// Try to get the default session data
defaultData, defaultIsNew, err := m.sessionManager.GetOrCreateSessionData(defaultSessionID, func() any {
return m.createKiteSessionData(defaultSessionID)
})

if err == nil && !defaultIsNew {
// Successfully got the default session, use it instead
defaultKiteData, err := m.extractKiteSessionData(defaultData, defaultSessionID)
if err == nil {
m.Logger.Info("Successfully used default session for new session ID",
"original_session_id", originalSessionID,
"default_session_id", defaultSessionID)
return defaultKiteData, false, nil
}
}
}

m.logSessionCreated(mcpSessionID)
} else {
m.logSessionRetrieved(mcpSessionID)

// If we used the default session, log that information
if originalSessionID != mcpSessionID {
m.Logger.Info("Successfully used default session for invalid session ID",
"original_session_id", originalSessionID,
"default_session_id", mcpSessionID)
}
}

return kiteData, isNew, nil
Expand Down Expand Up @@ -412,6 +468,36 @@ func (m *Manager) CompleteSession(mcpSessionID, kiteRequestToken string) error {
return nil
}

// Default session management methods

// SetDefaultSession stores the authenticated session ID as the default
func (m *Manager) SetDefaultSession(sessionID string) {
m.defaultSessionMu.Lock()
defer m.defaultSessionMu.Unlock()

m.defaultSessionID = sessionID
m.Logger.Info("Set default session ID for fallback authentication", "default_session_id", sessionID)
}

// GetDefaultSession retrieves the stored default session ID
func (m *Manager) GetDefaultSession() string {
m.defaultSessionMu.RLock()
defer m.defaultSessionMu.RUnlock()

return m.defaultSessionID
}

// ClearDefaultSession clears the stored default session ID
func (m *Manager) ClearDefaultSession() {
m.defaultSessionMu.Lock()
defer m.defaultSessionMu.Unlock()

if m.defaultSessionID != "" {
m.Logger.Info("Clearing default session ID", "previous_default_session_id", m.defaultSessionID)
m.defaultSessionID = ""
}
}

// Session management utility methods

// GetActiveSessionCount returns the number of active sessions
Expand Down Expand Up @@ -520,6 +606,9 @@ func (m *Manager) HandleKiteCallback() func(w http.ResponseWriter, r *http.Reque
}

m.Logger.Info("Kite session completed successfully", "session_id", mcpSessionID)

// Store this successfully authenticated session as the default
m.SetDefaultSession(mcpSessionID)

if err := m.renderSuccessTemplate(w); err != nil {
m.Logger.Error("Template failed to load - this is a fatal error", "error", err)
Expand Down
3 changes: 3 additions & 0 deletions mcp/setup_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ func (*LoginTool) Handler(manager *kc.Manager) server.ToolHandlerFunc {
mcpSessionID := mcpClientSession.SessionID()
manager.Logger.Info("Login tool called", "session_id", mcpSessionID)

// Clear any existing default session when starting a new login
manager.ClearDefaultSession()

// Get or create a Kite session for this MCP session
kiteSession, isNew, err := manager.GetOrCreateSession(mcpSessionID)
if err != nil {
Expand Down