Skip to content

Commit 997e02b

Browse files
authored
Handle tags at the end of images (#2401)
* We weren’t correctly separating the image name from the tag
1 parent db0ab69 commit 997e02b

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed

pkg/api/client.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,12 +378,12 @@ func decomposeImageName(image string) (string, string, string, string, error) {
378378
if imageComponents[0] != global.ReplicateRegistryHost {
379379
return "", "", "", "", r8_errors.ErrorBadRegistryHost
380380
}
381-
tagComponents := strings.Split(image, ":")
381+
tagComponents := strings.Split(imageComponents[2], ":")
382382
tag := ""
383383
if len(tagComponents) == 2 {
384384
tag = tagComponents[1]
385385
}
386-
return imageComponents[0], imageComponents[1], imageComponents[2], tag, nil
386+
return imageComponents[0], imageComponents[1], tagComponents[0], tag, nil
387387
}
388388

389389
func decomposeDraftSlug(slug string) (string, string, error) {

pkg/api/client_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,84 @@ func TestPullDraftSource(t *testing.T) {
251251
})
252252
require.NoError(t, err)
253253
}
254+
255+
func TestPullSourceWithTag(t *testing.T) {
256+
// Create file to pull
257+
dir := t.TempDir()
258+
predictPyPath := filepath.Join(dir, "predict.py")
259+
handle, err := os.Create(predictPyPath)
260+
require.NoError(t, err)
261+
handle.WriteString("import cog")
262+
err = handle.Close()
263+
require.NoError(t, err)
264+
info, err := os.Stat(predictPyPath)
265+
require.NoError(t, err)
266+
267+
// Setup mock web server for cog.replicate.com (token exchange)
268+
webServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
269+
switch r.URL.Path {
270+
case "/api/token/user":
271+
// Mock token exchange response
272+
//nolint:gosec
273+
tokenResponse := `{
274+
"keys": {
275+
"cog": {
276+
"key": "test-api-token",
277+
"expires_at": "2024-12-31T23:59:59Z"
278+
}
279+
}
280+
}`
281+
w.WriteHeader(http.StatusOK)
282+
w.Write([]byte(tokenResponse))
283+
default:
284+
w.WriteHeader(http.StatusNotFound)
285+
}
286+
}))
287+
defer webServer.Close()
288+
289+
// Setup mock API server for api.replicate.com (model and source endpoints)
290+
apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
291+
switch r.URL.Path {
292+
case "/v1/models/user/test/versions/12435/source":
293+
// Mock source pull endpoint
294+
var buf bytes.Buffer
295+
tw := tar.NewWriter(&buf)
296+
header, err := tar.FileInfoHeader(info, info.Name())
297+
require.NoError(t, err)
298+
header.Name = "predict.py"
299+
err = tw.WriteHeader(header)
300+
require.NoError(t, err)
301+
file, err := os.Open(predictPyPath)
302+
require.NoError(t, err)
303+
defer file.Close()
304+
_, err = io.Copy(tw, file)
305+
require.NoError(t, err)
306+
err = tw.Close()
307+
require.NoError(t, err)
308+
w.WriteHeader(http.StatusOK)
309+
w.Write(buf.Bytes())
310+
default:
311+
w.WriteHeader(http.StatusNotFound)
312+
}
313+
}))
314+
defer apiServer.Close()
315+
316+
webURL, err := url.Parse(webServer.URL)
317+
require.NoError(t, err)
318+
apiURL, err := url.Parse(apiServer.URL)
319+
require.NoError(t, err)
320+
321+
t.Setenv(env.SchemeEnvVarName, webURL.Scheme)
322+
t.Setenv(env.WebHostEnvVarName, webURL.Host)
323+
t.Setenv(env.APIHostEnvVarName, apiURL.Host)
324+
325+
// Setup mock command
326+
command := dockertest.NewMockCommand()
327+
webClient := web.NewClient(command, http.DefaultClient)
328+
329+
client := NewClient(command, http.DefaultClient, webClient)
330+
err = client.PullSource(t.Context(), "r8.im/user/test:12435", func(header *tar.Header, tr *tar.Reader) error {
331+
return nil
332+
})
333+
require.NoError(t, err)
334+
}

0 commit comments

Comments
 (0)