Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
9568145
Add mode for cog-runtime to use signals
NikhilSinha1 Oct 6, 2025
dd441e6
Use proper type hint
NikhilSinha1 Oct 6, 2025
dae112b
Make sure IPC URL exists when calling send_ipc
NikhilSinha1 Oct 6, 2025
71f255b
Linting
NikhilSinha1 Oct 6, 2025
29d6d6a
Use func
NikhilSinha1 Oct 6, 2025
17d6777
Wire up checkpointing in cog-runtime
NikhilSinha1 Oct 7, 2025
2f5cf99
Linting
NikhilSinha1 Oct 7, 2025
36b57c7
Lint fixes part 1
NikhilSinha1 Oct 7, 2025
636fdd8
nolint for the checkpointer file
NikhilSinha1 Oct 7, 2025
3075bad
Lint fixes part 2
NikhilSinha1 Oct 7, 2025
4aaf4bb
Write ready file when ready
NikhilSinha1 Oct 7, 2025
d44231e
Send ready signal
NikhilSinha1 Oct 7, 2025
ecbc90b
Standardize flag casing
NikhilSinha1 Oct 8, 2025
55073f3
Testing
NikhilSinha1 Oct 9, 2025
f14484b
Testing further
NikhilSinha1 Oct 9, 2025
99c0bb0
Ordering
NikhilSinha1 Oct 9, 2025
c0b9ff4
No shadowing
NikhilSinha1 Oct 9, 2025
ae1163d
Close TCP connections
NikhilSinha1 Oct 9, 2025
87eeac8
Correct perms, as well as some testing
NikhilSinha1 Oct 9, 2025
96666c3
Remove test logging
NikhilSinha1 Oct 9, 2025
c10f2ba
Comments
NikhilSinha1 Oct 9, 2025
056b883
Comments
NikhilSinha1 Oct 9, 2025
4a15dcd
Pass context to function
NikhilSinha1 Oct 10, 2025
64171ab
Linter
NikhilSinha1 Oct 10, 2025
788ccbe
Testing
NikhilSinha1 Oct 10, 2025
7dccc14
Testing
NikhilSinha1 Oct 10, 2025
74a6134
Move assignment earlier
NikhilSinha1 Oct 10, 2025
bb7a47e
Testing
NikhilSinha1 Oct 10, 2025
a1366c4
Do we need shell-job?
NikhilSinha1 Oct 13, 2025
f85fd2e
Remove stuff only needed for non-cog runtime
NikhilSinha1 Oct 13, 2025
df9a17f
Testing
NikhilSinha1 Oct 13, 2025
3b8fcfa
Return shell job to command
NikhilSinha1 Oct 13, 2025
a79c0f5
Testing v2
NikhilSinha1 Oct 13, 2025
bc1e608
Remove testing sleep
NikhilSinha1 Oct 14, 2025
9058499
Add verbose logging
NikhilSinha1 Oct 14, 2025
32c2e81
Different verbose logging cmd
NikhilSinha1 Oct 14, 2025
d77288b
Print logs from cuda checkpoint
NikhilSinha1 Oct 14, 2025
f0ebc37
More logging
NikhilSinha1 Oct 14, 2025
44d0a1e
log instead
NikhilSinha1 Oct 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cmd/cog/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type ServerCmd struct {
UseProcedureMode bool `help:"Enable procedure mode for concurrent predictions" name:"use-procedure-mode" env:"COG_USE_PROCEDURE_MODE"`
AwaitExplicitShutdown bool `help:"Wait for explicit shutdown signal instead of auto-shutdown" name:"await-explicit-shutdown" env:"COG_AWAIT_EXPLICIT_SHUTDOWN"`
OneShot bool `help:"Enable one-shot mode (single runner, wait for cleanup before ready)" name:"one-shot" env:"COG_ONE_SHOT"`
SignalMode bool `help:"Enable signal mode (use signals instead of webhooks for IPC communication)" name:"signal-mode" env:"COG_SIGNAL_MODE"`
UploadURL string `help:"Base URL for uploading prediction output files" name:"upload-url" env:"COG_UPLOAD_URL"`
WorkingDirectory string `help:"Override the working directory for predictions" name:"working-directory" env:"COG_WORKING_DIRECTORY"`
RunnerShutdownGracePeriod time.Duration `help:"Grace period before force-killing prediction runners" name:"runner-shutdown-grace-period" default:"600s" env:"COG_RUNNER_SHUTDOWN_GRACE_PERIOD"`
Expand Down Expand Up @@ -78,6 +79,7 @@ func buildServiceConfig(s *ServerCmd) (config.Config, error) {
WorkingDirectory: workingDir,
UploadURL: s.UploadURL,
IPCUrl: fmt.Sprintf("http://localhost:%d/_ipc", s.Port),
SignalMode: s.SignalMode,
MaxRunners: s.MaxRunners,
RunnerShutdownGracePeriod: s.RunnerShutdownGracePeriod,
CleanupTimeout: s.CleanupTimeout,
Expand Down
259 changes: 259 additions & 0 deletions internal/checkpointer/checkpointer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
// There are some commands in here that are susceptible to injection. However, cog
// is a vehicle to let people run their own code... so why go through the hassle of
// injection? Cog is not run with any more permissions than the user code.
//
//nolint:gosec // See above
package checkpointer

import (
"context"
"errors"
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"time"

"github.com/replicate/cog-runtime/internal/logging"
)

const (
	// Configuration environment variables read at startup.
	locationEnvVar          = "R8_LOCATION"             // selects the regional asset mirror used in the download URLs below
	shouldCheckpointEnvVar  = "R8_CUDA_CHECKPOINT"      // checkpointing is enabled only when this is exactly "true"
	leaseFileEnvVar         = "R8_LEASE_FILE"           // file whose deletion Prepare waits for before continuing
	cudaCheckpointDirEnvVar = "R8_CUDA_CHECKPOINT_DIR"  // root directory under which CRIU images are stored
	cudaReadyFileEnvVar     = "R8_CUDA_READY_LOCK_FILE" // presumably the path of the ready file written by WriteReadyFile — TODO confirm against writeCudaReadyFile

	// Dependencies for the checkpoint process: the cuda-checkpoint binary and
	// a CRIU tarball, downloaded into /tmp by downloadCUDACheckpointBinaries.
	cudaCheckpointURLFmtStr = "https://r8-public-assets-%s.cwobject.com/cuda-checkpoint"
	criuURLFmtStr           = "https://r8-public-assets-%s.cwobject.com/criu.tar.gz"
	cudaCheckpointPath      = "/tmp/cuda-checkpoint"
	criuPath                = "/tmp/criu"

	// Metadata storage paths
	checkpointSubdirName = "checkpoint" // subdirectory of the checkpoint dir holding CRIU image files
)

// errNoCheckpointDir is returned by Checkpoint when checkpointing is enabled
// but R8_CUDA_CHECKPOINT_DIR was not set in the environment.
var errNoCheckpointDir = errors.New("could not find checkpoint directory environment variable")

// FatalCheckpointError wraps an error that leaves the process in an
// unrecoverable state (for example, CUDA could not be toggled back on after a
// dump). Callers should treat it as a signal that the current runner cannot
// continue and must be killed and replaced.
type FatalCheckpointError struct {
	err error
}

// Error returns the message of the wrapped error.
func (e *FatalCheckpointError) Error() string {
	return e.err.Error()
}

// Unwrap exposes the wrapped error so that errors.Is and errors.As can
// inspect the underlying cause. Without it, the wrapped error would be
// unreachable to callers despite this type existing purely to wrap one.
func (e *FatalCheckpointError) Unwrap() error {
	return e.err
}

// Checkpointer manages the CUDA/CRIU checkpoint-and-restore lifecycle of a
// coglet process. The zero flow is: Prepare (download deps, detect an
// existing checkpoint), then either Restore (when HasCheckpoint is true) or
// Checkpoint after the process is ready. All operations are no-ops when
// checkpointing is disabled.
type Checkpointer interface {
	// Disable turns checkpointing off for the remainder of this instance's
	// lifetime, regardless of environment configuration.
	Disable()
	// HasCheckpoint reports whether Prepare found an existing checkpoint on
	// disk. Always false when disabled.
	HasCheckpoint() bool
	// Prepare downloads dependencies, waits for the IPC lease file to be
	// deleted (if configured), and detects any existing checkpoint.
	Prepare(ctx context.Context) error
	// Checkpoint dumps the running process identified by cmd with CRIU,
	// toggling CUDA off around the dump. waitFunc must block until the
	// process is in a checkpointable state.
	Checkpoint(ctx context.Context, cmd *exec.Cmd, waitFunc func() error) error
	// Restore returns an unstarted CRIU restore command plus a callback to
	// invoke after the command has been started (it toggles CUDA back on).
	Restore(ctx context.Context) (*exec.Cmd, func(context.Context) error, error)
	// WriteReadyFile signals readiness to the external checkpoint
	// orchestrator by writing the CUDA ready file.
	WriteReadyFile() error
}

// checkpointer is the environment-driven implementation of Checkpointer.
type checkpointer struct {
	enabled       bool                   // true when R8_CUDA_CHECKPOINT=="true" and Disable has not been called
	hasCheckpoint bool                   // set by Prepare when the checkpoint images dir is non-empty
	checkpointDir string                 // from R8_CUDA_CHECKPOINT_DIR; images live in <dir>/checkpoint
	leaseFile     string                 // from R8_LEASE_FILE; Prepare waits for this file to be deleted
	log           *logging.SugaredLogger // structured logger for error reporting
}

// NewCheckpointer builds a Checkpointer configured entirely from the process
// environment: checkpointing is enabled only when R8_CUDA_CHECKPOINT is
// exactly "true". The ctx parameter is currently unused; it is kept for
// signature stability.
func NewCheckpointer(ctx context.Context, log *logging.SugaredLogger) Checkpointer {
	c := &checkpointer{log: log}
	c.enabled = os.Getenv(shouldCheckpointEnvVar) == "true"
	c.checkpointDir = os.Getenv(cudaCheckpointDirEnvVar)
	c.leaseFile = os.Getenv(leaseFileEnvVar)
	return c
}

// Disable permanently turns checkpointing off for this instance, overriding
// whatever the environment configured at construction time.
func (c *checkpointer) Disable() { c.enabled = false }

// HasCheckpoint reports whether Prepare detected an existing checkpoint on
// disk. It is always false while checkpointing is disabled.
func (c *checkpointer) HasCheckpoint() bool {
	return c.enabled && c.hasCheckpoint
}

// Prepare readies the environment for checkpoint/restore: it downloads the
// cuda-checkpoint binary and CRIU tarball, optionally waits for the IPC
// lease file to be deleted, and records whether a previous checkpoint is
// already present on disk. It is a no-op when checkpointing is disabled.
func (c *checkpointer) Prepare(ctx context.Context) error {
	if !c.enabled {
		return nil
	}

	// Download dependencies.
	if err := downloadCUDACheckpointBinaries(ctx); err != nil {
		return err
	}

	// Wait for the IPC lease file to be deleted, when one is configured.
	if c.leaseFile != "" {
		if err := pollForFileDeletion(c.leaseFile, 5*time.Minute, 10*time.Second); err != nil {
			return err
		}
	}

	// A non-empty images directory means a prior checkpoint exists. A non-nil
	// error most likely means the directory does not exist yet, which we
	// treat the same as "no checkpoint".
	imagesDir := filepath.Join(c.checkpointDir, checkpointSubdirName)
	if empty, err := isDirEmpty(imagesDir); err == nil && !empty {
		c.hasCheckpoint = true
	}

	return nil
}

// Checkpoint dumps the running coglet process (cogletCmd) with CRIU, toggling
// CUDA off around the dump and back on afterwards. waitFunc is invoked first
// and must block until the process is in a checkpointable state. A
// *FatalCheckpointError is returned when CUDA cannot be toggled back on — in
// that case the process will hang indefinitely and the caller must kill it
// and start a fresh one without checkpointing. No-op when disabled.
func (c *checkpointer) Checkpoint(ctx context.Context, cogletCmd *exec.Cmd, waitFunc func() error) error {
	if !c.enabled {
		return nil
	}

	if c.checkpointDir == "" {
		return errNoCheckpointDir
	}

	// Block until the process is ready to be checkpointed.
	if err := waitFunc(); err != nil {
		return err
	}

	imagesDir := filepath.Join(c.checkpointDir, checkpointSubdirName)
	// Directories need the execute bit to be traversable; the previous 0o666
	// produced directories CRIU could not write images into. 0o755 is the
	// conventional directory mode.
	if err := os.MkdirAll(imagesDir, 0o755); err != nil {
		return err
	}

	pid := strconv.Itoa(cogletCmd.Process.Pid)

	// Find the PID of the command that is actually using the GPU.
	// NOTE(review): this assumes nvidia-smi reports exactly one compute app;
	// with multiple GPU processes the trimmed output would contain several
	// newline-separated PIDs — confirm the single-process invariant holds.
	cudaPIDBytes, err := exec.CommandContext(ctx, "nvidia-smi", "--query-compute-apps=pid", "--format=csv,noheader").Output()
	if err != nil {
		return err
	}

	cudaPID := strings.TrimSpace(string(cudaPIDBytes))

	// Toggle CUDA off so the process state can be dumped.
	cmd := exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID)
	if err := cmd.Run(); err != nil {
		return err
	}

	// CRIU checkpoint (leaving process running)
	cmd = exec.CommandContext(ctx, criuPath, "dump", "--shell-job", "--leave-running", "--tcp-close", "--images-dir", imagesDir, "--tree", pid)
	if err := cmd.Run(); err != nil {
		// Try to toggle CUDA back on. If we aren't able to restart CUDA, the process
		// will hang indefinitely, so we should kill it and try to start a new one
		// without checkpointing
		cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID)
		if cudaErr := cmd.Run(); cudaErr != nil {
			// Return a fatal error so upstream knows we cannot continue in the current state
			return &FatalCheckpointError{
				err: cudaErr,
			}
		}
		// Return the original checkpointing error
		return err
	}

	// Toggle CUDA back on. If we aren't able to restart CUDA, the process
	// will hang indefinitely, so we should kill it and try to start a new
	// one without checkpointing
	cmd = exec.CommandContext(ctx, cudaCheckpointPath, "--toggle", "--pid", cudaPID)
	if err := cmd.Run(); err != nil {
		// Return a fatal error so upstream knows we cannot continue in the current state
		return &FatalCheckpointError{
			err: err,
		}
	}

	return nil
}

// Restore builds (but does not start) a CRIU restore command for a previously
// dumped coglet process, and returns a callback that must be invoked after
// the command has been started — it relies on restoreCmd.Process being
// populated. The callback toggles CUDA back on for the restored process; on
// failure it best-effort kills the restored process (which would otherwise
// hang) and returns the error so the caller can start a fresh process.
// Returns (nil, nil, nil) when checkpointing is disabled.
//
// Leftover debug instrumentation (a `ps aux` dump and raw PID logging) has
// been removed; it ignored its own error and logged unstructured output.
func (c *checkpointer) Restore(ctx context.Context) (*exec.Cmd, func(context.Context) error, error) {
	if !c.enabled {
		return nil, nil, nil
	}

	// Set up restore command
	restoreCmd := exec.CommandContext(ctx, criuPath, "restore", "--shell-job", "--tcp-close", "--images-dir", filepath.Join(c.checkpointDir, checkpointSubdirName))

	// Set up callback function once restore is started
	callback := func(con context.Context) error {
		// Toggle CUDA on for the restored process
		cmd := exec.CommandContext(con, cudaCheckpointPath, "--toggle", "--pid", strconv.Itoa(restoreCmd.Process.Pid))
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			c.log.Errorw("failed to toggle CUDA on", "error", err)
			// If this command failed, we want to best effort try to kill the started process,
			// since we'll start a new one
			killProcess(restoreCmd) //nolint:errcheck // This is just best effort

			return err
		}

		return nil
	}

	// The restored command is a running instance of coglet
	return restoreCmd, callback, nil
}

// killProcess sends SIGKILL to cmd's process and then reaps it, waiting at
// most five seconds for the process to exit. It returns the Kill error if the
// signal could not be delivered, the Wait error (if any) when the process
// exits in time, and nil when the wait times out (the reaper goroutine keeps
// running and will eventually drain into the buffered channel).
func killProcess(cmd *exec.Cmd) error {
	if err := cmd.Process.Kill(); err != nil {
		return err
	}

	// Reap the process in the background; the buffer ensures the goroutine
	// never blocks even if we stop listening after the timeout.
	waitCh := make(chan error, 1)
	go func() { waitCh <- cmd.Wait() }()

	deadline := time.NewTimer(5 * time.Second)
	defer deadline.Stop()

	select {
	case waitErr := <-waitCh:
		return waitErr
	case <-deadline.C:
		return nil
	}
}

// WriteReadyFile signals readiness to the external checkpoint orchestrator by
// delegating to writeCudaReadyFile. It is a no-op unless R8_CUDA_CHECKPOINT
// is set to "true".
//
// NOTE(review): this re-reads the environment variable instead of checking
// c.enabled like every other method, so it still fires after Disable() has
// been called — confirm that is intentional.
func (c *checkpointer) WriteReadyFile() error {
	// If it isn't expected, make this a no-op
	if os.Getenv(shouldCheckpointEnvVar) != "true" {
		return nil
	}
	return writeCudaReadyFile()
}

func downloadCUDACheckpointBinaries(ctx context.Context) error {
location := os.Getenv("R8_LOCATION")

// Download the cuda-checkpoint binary
err := downloadAndChmod(fmt.Sprintf(cudaCheckpointURLFmtStr, location), cudaCheckpointPath)
if err != nil {
return fmt.Errorf("failed to download and chmod cuda-checkpoint binary: %w", err)
}
// CRIU gets downloaded as a tar with its dependencies. So we need to extract the tar, then
// link the LD_LIBRARY_PATH to the dependencies
dir := filepath.Dir(criuPath)
err = downloadAndUntar(ctx, fmt.Sprintf(criuURLFmtStr, location), dir)
if err != nil {
return fmt.Errorf("failed to download and untar CRIU: %w", err)
}
return updateEnvVar("LD_LIBRARY_PATH", filepath.Join(dir, "criu-lib"))
}
Loading
Loading