Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,14 @@ cog predict --use-replicate-token -i prompt="Hello"
# Multiple environment variables
cog run -e CUDA_VISIBLE_DEVICES=0 -e BATCH_SIZE=32 python train.py
```

# Selecting Ubuntu version for CUDA base image

To select a specific Ubuntu version for the CUDA base image, set the environment variable `COG_UBUNTU_VERSION` before building:

```bash
export COG_UBUNTU_VERSION=22.04
cog build --use-cog-base-image=false
```

If not set, the latest supported Ubuntu version will be used.
12 changes: 10 additions & 2 deletions pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
_ "embed"
"encoding/json"
"fmt"
"os"
"sort"
"strings"

Expand Down Expand Up @@ -255,13 +256,20 @@ func versionGreater(a string, b string) (bool, error) {

func CUDABaseImageFor(cuda string, cuDNN string) (string, error) {
var images []CUDABaseImage
ubuntuEnv := os.Getenv("COG_UBUNTU_VERSION")
for _, image := range CUDABaseImages {
if version.Matches(cuda, image.CUDA) && image.CuDNN == cuDNN {
images = append(images, image)
if ubuntuEnv == "" || image.Ubuntu == ubuntuEnv {
images = append(images, image)
}
}
}
if len(images) == 0 {
return "", fmt.Errorf("No matching base image for CUDA %s and CuDNN %s", cuda, cuDNN)
ubuntuMsg := ubuntuEnv
if ubuntuEnv == "" {
ubuntuMsg = "any"
}
return "", fmt.Errorf("No matching base image for CUDA %s, CuDNN %s, Ubuntu %s", cuda, cuDNN, ubuntuMsg)
}

sort.Slice(images, func(i, j int) bool {
Expand Down
40 changes: 40 additions & 0 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,46 @@ func TestCUDABaseImageTag(t *testing.T) {
require.Equal(t, "nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04", imageTag)
}

func TestCUDABaseImageTagWithUbuntuEnv(t *testing.T) {
// By default, CUDA 12.8 + Python 3.12 should select Ubuntu 24.04
os.Unsetenv("COG_UBUNTU_VERSION")
configDefault := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.12",
CUDA: "12.8.0",
CuDNN: "9",
},
}

err := configDefault.ValidateAndComplete("")
require.NoError(t, err)

imageTag, err := configDefault.CUDABaseImageTag()
require.NoError(t, err)
require.Equal(t, "nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04", imageTag)

// If COG_UBUNTU_VERSION is set to 22.04, should select Ubuntu 22.04 image
os.Setenv("COG_UBUNTU_VERSION", "22.04")
configEnv := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.12",
CUDA: "12.8.0",
CuDNN: "9",
},
}

err = configEnv.ValidateAndComplete("")
require.NoError(t, err)

imageTag, err = configEnv.CUDABaseImageTag()
require.NoError(t, err)
require.Equal(t, "nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", imageTag)

os.Unsetenv("COG_UBUNTU_VERSION")
}

func TestBuildRunItemStringYAML(t *testing.T) {
type BuildWrapper struct {
Build *Build `yaml:"build"`
Expand Down