Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 61 additions & 5 deletions apps/workspace-engine/pkg/db/workspaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,25 @@ func GetAllWorkspaceIDs(ctx context.Context) ([]string, error) {
}

type WorkspaceSnapshot struct {
WorkspaceID string
Path string
Timestamp time.Time
Partition int32
NumPartitions int32
Offset int64
}

const WORKSPACE_SNAPSHOT_SELECT_QUERY = `
SELECT path, timestamp, partition, num_partitions FROM workspace_snapshot WHERE workspace_id = $1 ORDER BY timestamp DESC LIMIT 1
SELECT
workspace_id,
path,
timestamp,
partition,
num_partitions,
offset
FROM workspace_snapshot
WHERE workspace_id = $1
ORDER BY offset DESC LIMIT 1
`

func GetWorkspaceSnapshot(ctx context.Context, workspaceID string) (*WorkspaceSnapshot, error) {
Expand All @@ -108,10 +119,12 @@ func GetWorkspaceSnapshot(ctx context.Context, workspaceID string) (*WorkspaceSn

workspaceSnapshot := &WorkspaceSnapshot{}
err = db.QueryRow(ctx, WORKSPACE_SNAPSHOT_SELECT_QUERY, workspaceID).Scan(
&workspaceSnapshot.WorkspaceID,
&workspaceSnapshot.Path,
&workspaceSnapshot.Timestamp,
&workspaceSnapshot.Partition,
&workspaceSnapshot.NumPartitions,
&workspaceSnapshot.Offset,
)
if err != nil {
if err == pgx.ErrNoRows {
Expand All @@ -122,19 +135,62 @@ func GetWorkspaceSnapshot(ctx context.Context, workspaceID string) (*WorkspaceSn
return workspaceSnapshot, nil
}

func GetLatestWorkspaceSnapshots(ctx context.Context, workspaceIDs []string) (map[string]*WorkspaceSnapshot, error) {
if len(workspaceIDs) == 0 {
return nil, nil
}

db, err := GetDB(ctx)
if err != nil {
return nil, err
}
defer db.Release()

const query = `
SELECT DISTINCT ON (workspace_id) workspace_id, path, timestamp, partition, num_partitions, offset
FROM workspace_snapshot
WHERE workspace_id = ANY($1)
ORDER BY workspace_id, offset DESC
`
rows, err := db.Query(ctx, query, workspaceIDs)
if err != nil {
return nil, err
}
defer rows.Close()

var snapshots []*WorkspaceSnapshot
for rows.Next() {
var snapshot WorkspaceSnapshot
err := rows.Scan(&snapshot.WorkspaceID, &snapshot.Path, &snapshot.Timestamp, &snapshot.Partition, &snapshot.NumPartitions, &snapshot.Offset)
if err != nil {
return nil, err
}
snapshots = append(snapshots, &snapshot)
}
if err := rows.Err(); err != nil {
return nil, err
}

snapshotMap := make(map[string]*WorkspaceSnapshot)
for _, snapshot := range snapshots {
snapshotMap[snapshot.WorkspaceID] = snapshot
}
return snapshotMap, nil
}

const WORKSPACE_SNAPSHOT_INSERT_QUERY = `
INSERT INTO workspace_snapshot (workspace_id, path, timestamp, partition, num_partitions)
VALUES ($1, $2, $3, $4, $5)
INSERT INTO workspace_snapshot (workspace_id, path, timestamp, partition, num_partitions, offset)
VALUES ($1, $2, $3, $4, $5, $6)
`

func WriteWorkspaceSnapshot(ctx context.Context, workspaceID string, snapshot *WorkspaceSnapshot) error {
func WriteWorkspaceSnapshot(ctx context.Context, snapshot *WorkspaceSnapshot) error {
db, err := GetDB(ctx)
if err != nil {
return err
}
defer db.Release()

_, err = db.Exec(ctx, WORKSPACE_SNAPSHOT_INSERT_QUERY, workspaceID, snapshot.Path, snapshot.Timestamp, snapshot.Partition, snapshot.NumPartitions)
_, err = db.Exec(ctx, WORKSPACE_SNAPSHOT_INSERT_QUERY, snapshot.WorkspaceID, snapshot.Path, snapshot.Timestamp, snapshot.Partition, snapshot.NumPartitions, snapshot.Offset)
if err != nil {
return err
}
Expand Down
15 changes: 14 additions & 1 deletion apps/workspace-engine/pkg/events/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,14 @@ func NewEventListener(handlers HandlerRegistry) *EventListener {
return &EventListener{handlers: handlers}
}

type OffsetTracker struct {
LastCommittedOffset int64
LastWorkspaceOffset int64
MessageOffset int64
}

// ListenAndRoute processes incoming Kafka messages and routes them to the appropriate handler
func (el *EventListener) ListenAndRoute(ctx context.Context, msg *kafka.Message) (*workspace.Workspace, error) {
func (el *EventListener) ListenAndRoute(ctx context.Context, msg *kafka.Message, offsetTracker OffsetTracker) (*workspace.Workspace, error) {
ctx, span := tracer.Start(ctx, "ListenAndRoute",
trace.WithAttributes(
attribute.String("kafka.topic", *msg.TopicPartition.Topic),
Expand Down Expand Up @@ -149,6 +155,13 @@ func (el *EventListener) ListenAndRoute(ctx context.Context, msg *kafka.Message)
return nil, fmt.Errorf("workspace not found: %s", rawEvent.WorkspaceID)
}

isReplay := offsetTracker.MessageOffset <= offsetTracker.LastCommittedOffset
ws.Store().SetIsReplay(isReplay)

if offsetTracker.MessageOffset <= offsetTracker.LastWorkspaceOffset {
return ws, nil
}

ctx = changeset.WithChangeSet(ctx, changeSet)

if err := handler(ctx, ws, rawEvent); err != nil {
Expand Down
46 changes: 46 additions & 0 deletions apps/workspace-engine/pkg/kafka/consumer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package kafka

import (
"context"
"math"
"workspace-engine/pkg/db"

"github.com/charmbracelet/log"
"github.com/confluentinc/confluent-kafka-go/v2/kafka"
)
Expand All @@ -27,3 +31,45 @@ func createConsumer() (*kafka.Consumer, error) {

return c, nil
}

func getEarliestOffset(snapshots map[string]*db.WorkspaceSnapshot) int64 {
beginning := int64(kafka.OffsetBeginning)
if len(snapshots) == 0 {
return beginning
}

earliestOffset := int64(math.MaxInt64)
for _, snapshot := range snapshots {
if snapshot.Offset < earliestOffset {
earliestOffset = snapshot.Offset
}
}
if earliestOffset == math.MaxInt64 {
return beginning
}
return earliestOffset
}

func setOffsets(ctx context.Context, consumer *kafka.Consumer, partitionWorkspaceMap map[int32][]string) {
for partition, workspaceIDs := range partitionWorkspaceMap {
snapshots, err := db.GetLatestWorkspaceSnapshots(ctx, workspaceIDs)
if err != nil {
log.Error("Failed to get latest workspace snapshots", "error", err)
continue
}

earliestOffset := getEarliestOffset(snapshots)
effectiveOffset := earliestOffset
if effectiveOffset > 0 {
effectiveOffset = effectiveOffset + 1
}
if err := consumer.Seek(kafka.TopicPartition{
Topic: &Topic,
Partition: partition,
Offset: kafka.Offset(effectiveOffset),
}, 0); err != nil {
log.Error("Failed to seek to earliest offset", "error", err)
continue
}
}
}
Comment on lines +57 to +79
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Partition mismatch risk: filter snapshots by current partition before computing earliest offset.

A workspace’s latest snapshot might be from a different partition (e.g., after partition-count changes). Using that offset to seek this partition is incorrect.

Apply this diff:

-        snapshots, err := db.GetLatestWorkspaceSnapshots(ctx, workspaceIDs)
+        snapshots, err := db.GetLatestWorkspaceSnapshots(ctx, workspaceIDs)
         if err != nil {
             log.Error("Failed to get latest workspace snapshots", "error", err)
             continue
         }
 
-        earliestOffset := getEarliestOffset(snapshots)
+        // Only consider snapshots that belong to this partition.
+        filtered := make(map[string]*db.WorkspaceSnapshot, len(snapshots))
+        for wid, s := range snapshots {
+            if s != nil && s.Partition == partition {
+                filtered[wid] = s
+            }
+        }
+        earliestOffset := getEarliestOffset(filtered)
         effectiveOffset := earliestOffset
-        if effectiveOffset > 0 {
-            effectiveOffset = effectiveOffset + 1
+        if effectiveOffset >= 0 {
+            effectiveOffset++
         }
         if err := consumer.Seek(kafka.TopicPartition{
             Topic:     &Topic,
             Partition: partition,
             Offset:    kafka.Offset(effectiveOffset),
         }, 0); err != nil {
-            log.Error("Failed to seek to earliest offset", "error", err)
+            log.Error("Failed to seek", "partition", partition, "targetOffset", effectiveOffset, "error", err)
             continue
         }
🤖 Prompt for AI Agents
In apps/workspace-engine/pkg/kafka/consumer.go around lines 53 to 75, the code
calls db.GetLatestWorkspaceSnapshots for a set of workspace IDs then computes an
earliestOffset from all returned snapshots but doesn't restrict them to the
current Kafka partition; this can produce an incorrect offset when snapshots
belong to other partitions. Fix by filtering the snapshots returned by
GetLatestWorkspaceSnapshots to only those whose Partition equals the current
partition variable before calling getEarliestOffset; if the filtered slice is
empty, log/continue and skip seeking for that partition, otherwise compute
earliestOffset from the filtered snapshots and proceed with the existing
effectiveOffset/Seek logic.

98 changes: 85 additions & 13 deletions apps/workspace-engine/pkg/kafka/kafka.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,30 @@ package kafka

import (
"context"
"encoding/json"
"fmt"
"os"
"strconv"
"time"

"workspace-engine/pkg/db"
"workspace-engine/pkg/events"
eventHanlder "workspace-engine/pkg/events/handler"
"workspace-engine/pkg/oapi"
"workspace-engine/pkg/workspace"
wskafka "workspace-engine/pkg/workspace/kafka"

"github.com/aws/smithy-go/ptr"
"github.com/charmbracelet/log"
"github.com/confluentinc/confluent-kafka-go/v2/kafka"
)

// Configuration variables loaded from environment
var (
Topic = getEnv("KAFKA_TOPIC", "workspace-events")
GroupID = getEnv("KAFKA_GROUP_ID", "workspace-engine")
Brokers = getEnv("KAFKA_BROKERS", "localhost:9092")
Topic = getEnv("KAFKA_TOPIC", "workspace-events")
GroupID = getEnv("KAFKA_GROUP_ID", "workspace-engine")
Brokers = getEnv("KAFKA_BROKERS", "localhost:9092")
MinSnapshotDistance = getEnvInt("SNAPSHOT_DISTANCE_MINUTES", 60)
)

// getEnv retrieves an environment variable or returns a default value
Expand All @@ -32,6 +37,40 @@ func getEnv(varName string, defaultValue string) string {
return v
}

// getEnvInt retrieves an integer environment variable or returns a default value
func getEnvInt(varName string, defaultValue int64) int64 {
v := os.Getenv(varName)
if v == "" {
return defaultValue
}
i, err := strconv.ParseInt(v, 10, 64)
if err != nil {
log.Warn("Failed to parse environment variable as integer, using default", "var", varName, "value", v, "default", defaultValue)
return defaultValue
}
return i
}

func getLastSnapshot(ctx context.Context, msg *kafka.Message) (*db.WorkspaceSnapshot, error) {
var rawEvent eventHanlder.RawEvent
if err := json.Unmarshal(msg.Value, &rawEvent); err != nil {
log.Error("Failed to unmarshal event", "error", err, "message", string(msg.Value))
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
}

return db.GetWorkspaceSnapshot(ctx, rawEvent.WorkspaceID)
}
Comment on lines +39 to +47
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don’t log full payloads on JSON errors.

Logging string(msg.Value) risks PII leakage and noisy logs.

 func getLastSnapshot(ctx context.Context, msg *kafka.Message) (*db.WorkspaceSnapshot, error) {
-    var rawEvent eventHanlder.RawEvent
+    var rawEvent eventHanlder.RawEvent
     if err := json.Unmarshal(msg.Value, &rawEvent); err != nil {
-        log.Error("Failed to unmarshal event", "error", err, "message", string(msg.Value))
+        log.Error("Failed to unmarshal event", "error", err,
+            "topic", *msg.TopicPartition.Topic,
+            "partition", msg.TopicPartition.Partition,
+            "offset", msg.TopicPartition.Offset)
         return nil, fmt.Errorf("failed to unmarshal event: %w", err)
     }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
func getLastSnapshot(ctx context.Context, msg *kafka.Message) (*db.WorkspaceSnapshot, error) {
var rawEvent eventHanlder.RawEvent
if err := json.Unmarshal(msg.Value, &rawEvent); err != nil {
log.Error("Failed to unmarshal event", "error", err, "message", string(msg.Value))
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
}
return db.GetWorkspaceSnapshot(ctx, rawEvent.WorkspaceID)
}
func getLastSnapshot(ctx context.Context, msg *kafka.Message) (*db.WorkspaceSnapshot, error) {
var rawEvent eventHanlder.RawEvent
if err := json.Unmarshal(msg.Value, &rawEvent); err != nil {
log.Error("Failed to unmarshal event", "error", err,
"topic", *msg.TopicPartition.Topic,
"partition", msg.TopicPartition.Partition,
"offset", msg.TopicPartition.Offset)
return nil, fmt.Errorf("failed to unmarshal event: %w", err)
}
return db.GetWorkspaceSnapshot(ctx, rawEvent.WorkspaceID)
}
🤖 Prompt for AI Agents
In apps/workspace-engine/pkg/kafka/kafka.go around lines 39 to 47, the current
error log prints the full message payload (string(msg.Value)) which can leak PII
and produce noisy logs; change the logging to avoid printing the full payload —
instead log only safe metadata (e.g., message key, topic/partition/offset if
available, workspace ID parsed from the partial/unmarshaled data if present) and
a truncated or hashed representation of the payload (or its length), then return
the error; remove direct string(msg.Value) from logs and replace it with a
sanitized indicator (e.g., payload length or first N bytes or a SHA256 hex) so
debugging info remains useful without exposing full data.


func getLastWorkspaceOffset(snapshot *db.WorkspaceSnapshot) int64 {
beginning := int64(kafka.OffsetBeginning)

if snapshot == nil {
return beginning
}

return snapshot.Offset
}
Comment on lines +49 to +57
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Guard against cross-partition/topology mismatches when using stored offsets.

Using a stored offset from a different partition or topology can skip/duplicate processing.

-func getLastWorkspaceOffset(snapshot *db.WorkspaceSnapshot) int64 {
+func getLastWorkspaceOffset(snapshot *db.WorkspaceSnapshot, currentPartition int32, currentNumPartitions int32) int64 {
     beginning := int64(kafka.OffsetBeginning)
 
     if snapshot == nil {
         return beginning
     }
 
-    return snapshot.Offset
+    if snapshot.Partition != currentPartition || snapshot.NumPartitions != currentNumPartitions {
+        return beginning
+    }
+    return snapshot.Offset
 }
@@
-    lastWorkspaceOffset := getLastWorkspaceOffset(lastSnapshot)
+    lastWorkspaceOffset := getLastWorkspaceOffset(
+        lastSnapshot,
+        msg.TopicPartition.Partition,
+        numPartitions,
+    )

Also applies to: 175-182

🤖 Prompt for AI Agents
In apps/workspace-engine/pkg/kafka/kafka.go around lines 49-57 (and similarly at
175-182), the function currently returns a stored offset without validating it
against the current partition/topology; update the logic to verify that the
snapshot is non-nil and that its Partition (and any topology identifier used by
this service) matches the current consumer partition/topology before returning
snapshot.Offset; if the partition/topology differs or the snapshot lacks
matching identifiers, ignore the stored offset and return kafka.OffsetBeginning
(or the earliest safe offset) to avoid skipping/duplicating messages. Ensure you
add the necessary fields/checks (partition and topology id) and unit tests for
both matching and mismatching cases.


// RunConsumerWithWorkspaceLoader starts the Kafka consumer with workspace-based offset resume
//
// Flow:
Expand Down Expand Up @@ -83,11 +122,17 @@ func RunConsumer(ctx context.Context) error {
}
log.Info("Partition assignment complete", "assigned", assignedPartitions)

allWorkspaceIDs, err := wskafka.GetAssignedWorkspaceIDs(ctx, assignedPartitions, numPartitions)
partitionWorkspaceMap, err := wskafka.GetAssignedWorkspaceIDs(ctx, assignedPartitions, numPartitions)
if err != nil {
return fmt.Errorf("failed to get assigned workspace IDs: %w", err)
}

// Flatten the map to get all workspace IDs
var allWorkspaceIDs []string
for _, workspaceIDs := range partitionWorkspaceMap {
allWorkspaceIDs = append(allWorkspaceIDs, workspaceIDs...)
}

storage := workspace.NewFileStorage("./state")
if workspace.IsGCSStorageEnabled() {
storage, err = workspace.NewGCSStorageClient(ctx)
Expand Down Expand Up @@ -118,6 +163,8 @@ func RunConsumer(ctx context.Context) error {
// Start consuming messages
handler := events.NewEventHandler()

setOffsets(ctx, consumer, partitionWorkspaceMap)

for {
// Check for cancellation
select {
Expand All @@ -134,21 +181,46 @@ func RunConsumer(ctx context.Context) error {
continue
}

ws, err := handler.ListenAndRoute(ctx, msg)
lastSnapshot, err := getLastSnapshot(ctx, msg)
if err != nil {
log.Error("Failed to route message", "error", err)
log.Error("Failed to get last snapshot", "error", err)
continue
}

snapshot := &db.WorkspaceSnapshot{
Path: fmt.Sprintf("%s.gob", ws.ID),
Timestamp: msg.Timestamp,
Partition: int32(msg.TopicPartition.Partition),
NumPartitions: numPartitions,
messageOffset := int64(msg.TopicPartition.Offset)
lastCommittedOffset, err := getCommittedOffset(consumer, msg.TopicPartition.Partition)
if err != nil {
log.Error("Failed to get committed offset", "error", err)
continue
}
lastWorkspaceOffset := getLastWorkspaceOffset(lastSnapshot)

offsetTracker := eventHanlder.OffsetTracker{
LastCommittedOffset: lastCommittedOffset,
LastWorkspaceOffset: lastWorkspaceOffset,
MessageOffset: messageOffset,
}

ws, err := handler.ListenAndRoute(ctx, msg, offsetTracker)
if err != nil {
log.Error("Failed to route message", "error", err)
continue
}

if err := workspace.Save(ctx, storage, ws, snapshot); err != nil {
log.Error("Failed to save workspace", "workspaceID", ws.ID, "snapshotPath", snapshot.Path, "error", err)
shouldSaveSnapshot := lastSnapshot == nil || lastSnapshot.Timestamp.Before(msg.Timestamp.Add(-time.Duration(MinSnapshotDistance)*time.Minute))
if shouldSaveSnapshot {
snapshot := &db.WorkspaceSnapshot{
WorkspaceID: ws.ID,
Path: fmt.Sprintf("%s.gob", ws.ID),
Timestamp: msg.Timestamp,
Partition: int32(msg.TopicPartition.Partition),
Offset: int64(msg.TopicPartition.Offset),
NumPartitions: numPartitions,
}

if err := workspace.Save(ctx, storage, ws, snapshot); err != nil {
log.Error("Failed to save workspace", "workspaceID", ws.ID, "snapshotPath", snapshot.Path, "error", err)
}
}

// Commit offset to Kafka
Expand Down
23 changes: 23 additions & 0 deletions apps/workspace-engine/pkg/kafka/offset.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,26 @@ func getTopicPartitionCount(c *kafka.Consumer) (int32, error) {

return numPartitions, nil
}

// getCommittedOffset retrieves the last committed offset for a partition
func getCommittedOffset(consumer *kafka.Consumer, partition int32) (int64, error) {
partitions := []kafka.TopicPartition{
{
Topic: &Topic,
Partition: partition,
Offset: kafka.OffsetStored, // This fetches the committed offset
},
}

committed, err := consumer.Committed(partitions, 5000)
if err != nil {
return int64(kafka.OffsetInvalid), err
}

if len(committed) == 0 || committed[0].Offset == kafka.OffsetInvalid {
// No committed offset yet, this is the beginning
return int64(kafka.OffsetBeginning), nil
}

return int64(committed[0].Offset), nil
}
Loading
Loading