Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@ import (
"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/docker"
"github.com/replicate/cog/pkg/docker/command"
"github.com/replicate/cog/pkg/http"
"github.com/replicate/cog/pkg/image"
r8_path "github.com/replicate/cog/pkg/path"
"github.com/replicate/cog/pkg/predict"
"github.com/replicate/cog/pkg/registry"
"github.com/replicate/cog/pkg/util/console"
"github.com/replicate/cog/pkg/util/files"
"github.com/replicate/cog/pkg/util/mime"
"github.com/replicate/cog/pkg/web"
)

const StdinPath = "-"
Expand All @@ -42,6 +44,7 @@ var (
setupTimeout uint32
useReplicateAPIToken bool
inputJSON string
replicateUsername string
)

func newPredictCommand() *cobra.Command {
Expand Down Expand Up @@ -69,6 +72,7 @@ the prediction on that.`,
addFastFlag(cmd)
addLocalImage(cmd)
addConfigFlag(cmd)
addReplicateUsernameFlag(cmd)

cmd.Flags().StringArrayVarP(&inputFlags, "input", "i", []string{}, "Inputs, in the form name=value. if value is prefixed with @, then it is read from a file on disk. E.g. -i [email protected]")
cmd.Flags().StringVarP(&outPath, "output", "o", "", "Output path")
Expand Down Expand Up @@ -253,6 +257,19 @@ func cmdPredict(cmd *cobra.Command, args []string) error {
}
}

if replicateUsername != "" {
client, err := http.ProvideHTTPClient(ctx, dockerClient)
if err != nil {
return err
}
webClient := web.NewClient(dockerClient, client)
token, err := webClient.FetchAPIToken(ctx, replicateUsername)
if err != nil {
return err
}
envFlags = append(envFlags, fmt.Sprintf("REPLICATE_API_TOKEN=%s", token))
}

console.Info("")
console.Infof("Starting Docker image %s and running setup()...", imageName)

Expand Down Expand Up @@ -643,3 +660,7 @@ func parseInputFlags(inputs []string, schema *openapi3.T) (predict.Inputs, error
func addSetupTimeoutFlag(cmd *cobra.Command) {
cmd.Flags().Uint32Var(&setupTimeout, "setup-timeout", 5*60, "The timeout for a container to setup (in seconds).")
}

func addReplicateUsernameFlag(cmd *cobra.Command) {
cmd.Flags().StringVarP(&replicateUsername, "replicate-username", "u", "", "The principal to use if the prediction requires a token.")
}