@@ -251,3 +251,84 @@ func TestPullDraftSource(t *testing.T) {
251
251
})
252
252
require .NoError (t , err )
253
253
}
254
+
255
+ func TestPullSourceWithTag (t * testing.T ) {
256
+ // Create file to pull
257
+ dir := t .TempDir ()
258
+ predictPyPath := filepath .Join (dir , "predict.py" )
259
+ handle , err := os .Create (predictPyPath )
260
+ require .NoError (t , err )
261
+ handle .WriteString ("import cog" )
262
+ err = handle .Close ()
263
+ require .NoError (t , err )
264
+ info , err := os .Stat (predictPyPath )
265
+ require .NoError (t , err )
266
+
267
+ // Setup mock web server for cog.replicate.com (token exchange)
268
+ webServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
269
+ switch r .URL .Path {
270
+ case "/api/token/user" :
271
+ // Mock token exchange response
272
+ //nolint:gosec
273
+ tokenResponse := `{
274
+ "keys": {
275
+ "cog": {
276
+ "key": "test-api-token",
277
+ "expires_at": "2024-12-31T23:59:59Z"
278
+ }
279
+ }
280
+ }`
281
+ w .WriteHeader (http .StatusOK )
282
+ w .Write ([]byte (tokenResponse ))
283
+ default :
284
+ w .WriteHeader (http .StatusNotFound )
285
+ }
286
+ }))
287
+ defer webServer .Close ()
288
+
289
+ // Setup mock API server for api.replicate.com (model and source endpoints)
290
+ apiServer := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
291
+ switch r .URL .Path {
292
+ case "/v1/models/user/test/versions/12435/source" :
293
+ // Mock source pull endpoint
294
+ var buf bytes.Buffer
295
+ tw := tar .NewWriter (& buf )
296
+ header , err := tar .FileInfoHeader (info , info .Name ())
297
+ require .NoError (t , err )
298
+ header .Name = "predict.py"
299
+ err = tw .WriteHeader (header )
300
+ require .NoError (t , err )
301
+ file , err := os .Open (predictPyPath )
302
+ require .NoError (t , err )
303
+ defer file .Close ()
304
+ _ , err = io .Copy (tw , file )
305
+ require .NoError (t , err )
306
+ err = tw .Close ()
307
+ require .NoError (t , err )
308
+ w .WriteHeader (http .StatusOK )
309
+ w .Write (buf .Bytes ())
310
+ default :
311
+ w .WriteHeader (http .StatusNotFound )
312
+ }
313
+ }))
314
+ defer apiServer .Close ()
315
+
316
+ webURL , err := url .Parse (webServer .URL )
317
+ require .NoError (t , err )
318
+ apiURL , err := url .Parse (apiServer .URL )
319
+ require .NoError (t , err )
320
+
321
+ t .Setenv (env .SchemeEnvVarName , webURL .Scheme )
322
+ t .Setenv (env .WebHostEnvVarName , webURL .Host )
323
+ t .Setenv (env .APIHostEnvVarName , apiURL .Host )
324
+
325
+ // Setup mock command
326
+ command := dockertest .NewMockCommand ()
327
+ webClient := web .NewClient (command , http .DefaultClient )
328
+
329
+ client := NewClient (command , http .DefaultClient , webClient )
330
+ err = client .PullSource (t .Context (), "r8.im/user/test:12435" , func (header * tar.Header , tr * tar.Reader ) error {
331
+ return nil
332
+ })
333
+ require .NoError (t , err )
334
+ }
0 commit comments