1+ // Copyright 2025 The Kubeflow Authors
2+ //
3+ // Licensed under the Apache License, Version 2.0 (the "License");
4+ // you may not use this file except in compliance with the License.
5+ // You may obtain a copy of the License at
6+ //
7+ // https://www.apache.org/licenses/LICENSE-2.0
8+ //
9+ // Unless required by applicable law or agreed to in writing, software
10+ // distributed under the License is distributed on an "AS IS" BASIS,
11+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ // See the License for the specific language governing permissions and
13+ // limitations under the License.
14+
115package main
216
317import (
@@ -94,13 +108,14 @@ func parseDriverRequestArgs(r *http.Request) (*api.DriverPluginArgs, error) {
94108 if err := json .NewDecoder (r .Body ).Decode (& body ); err != nil {
95109 return nil , fmt .Errorf ("failed to parse driver request body: %v" , err )
96110 }
97- if body .Template == nil {
111+ switch {
112+ case body .Template == nil :
98113 return nil , fmt .Errorf ("driver request body.Template is empty" )
99- } else if body .Template .Plugin == nil {
114+ case body .Template .Plugin == nil :
100115 return nil , fmt .Errorf ("driver request body.Template.Plugin is empty" )
101- } else if body .Template .Plugin .DriverPlugin == nil {
116+ case body .Template .Plugin .DriverPlugin == nil :
102117 return nil , fmt .Errorf ("driver request body.Template.Plugin.DriverPlugin is empty" )
103- } else if body .Template .Plugin .DriverPlugin .Args == nil {
118+ case body .Template .Plugin .DriverPlugin .Args == nil :
104119 return nil , fmt .Errorf ("driver request body.Template.Plugin.Args is empty" )
105120 }
106121 args := body .Template .Plugin .DriverPlugin .Args
@@ -118,7 +133,7 @@ func drive(args api.DriverPluginArgs) (execution *driver.Execution, err error) {
118133 }()
119134 ctx := context .Background ()
120135
121- proxy .InitializeConfig (args .HttpProxy , args .HttpsProxy , args .NoProxy )
136+ proxy .InitializeConfig (args .HTTPProxy , args .HTTPSProxy , args .NoProxy )
122137
123138 glog .Infof ("input ComponentSpec:%s\n " , prettyPrint (args .Component ))
124139 componentSpec := & pipelinespec.ComponentSpec {}
@@ -149,7 +164,7 @@ func drive(args api.DriverPluginArgs) (execution *driver.Execution, err error) {
149164 return nil , fmt .Errorf ("failed to unmarshal runtime config, error: %w\n runtimeConfig: %v" , err , args .RuntimeConfig )
150165 }
151166 }
152- k8sExecCfg , err := parseExecConfigJson (& args .KubernetesConfig )
167+ k8sExecCfg , err := parseExecConfigJSON (& args .KubernetesConfig )
153168 if err != nil {
154169 return nil , err
155170 }
@@ -160,7 +175,9 @@ func drive(args api.DriverPluginArgs) (execution *driver.Execution, err error) {
160175 var tlsCfg * tls.Config
161176 if args .MetadataTLSEnabled {
162177 tlsCfg , err = util .GetTLSConfig (args .CACertPath )
163- return nil , fmt .Errorf ("unable to drive driver: failed to load TLS configuration: %v" , err )
178+ if err != nil {
179+ return nil , fmt .Errorf ("unable to drive driver: failed to load TLS configuration: %v" , err )
180+ }
164181 }
165182 client , err := newMlmdClient (tlsCfg )
166183 if err != nil {
@@ -180,23 +197,28 @@ func drive(args api.DriverPluginArgs) (execution *driver.Execution, err error) {
180197 return nil , fmt .Errorf ("failed to parse iteration index, error: %w" , err )
181198 }
182199 options := driver.Options {
183- PipelineName : args .PipelineName ,
184- RunID : args .RunID ,
185- RunName : args .RunName ,
186- RunDisplayName : args .RunDisplayName ,
187- Namespace : namespace ,
188- Component : componentSpec ,
189- Task : taskSpec ,
190- DAGExecutionID : dagExecutionID ,
191- IterationIndex : iterationIndex ,
192- PublishLogs : args .PublishLogs ,
193- CacheDisabled : args .CacheDisabledFlag ,
194- DriverType : args .Type ,
195- TaskName : args .TaskName ,
200+ PipelineName : args .PipelineName ,
201+ RunID : args .RunID ,
202+ RunName : args .RunName ,
203+ RunDisplayName : args .RunDisplayName ,
204+ Namespace : namespace ,
205+ Component : componentSpec ,
206+ Task : taskSpec ,
207+ DAGExecutionID : dagExecutionID ,
208+ IterationIndex : iterationIndex ,
209+ PublishLogs : args .PublishLogs ,
210+ CacheDisabled : args .CacheDisabledFlag ,
211+ DriverType : args .Type ,
212+ TaskName : args .TaskName ,
213+ MLPipelineTLSEnabled : args .MlPipelineTLSEnabled ,
214+ MLMDServerAddress : * mlmdServerAddress ,
215+ MLMDServerPort : * mlmdServerPort ,
216+ CaCertPath : args .CACertPath ,
196217 }
218+
197219 var driverErr error
198220 switch args .Type {
199- case ROOT_DAG :
221+ case RootDag :
200222 options .RuntimeConfig = runtimeConfig
201223 execution , driverErr = driver .RootDAG (ctx , options , client )
202224 case DAG :
@@ -224,19 +246,16 @@ func drive(args api.DriverPluginArgs) (execution *driver.Execution, err error) {
224246}
225247
226248func validate (args api.DriverPluginArgs ) error {
227- if args .Type == "" {
249+ switch {
250+ case args .Type == "" :
228251 return fmt .Errorf ("argument type must be specified" )
229- }
230- if args .HttpProxy == unsetProxyArgValue {
252+ case args .HTTPProxy == unsetProxyArgValue :
231253 return fmt .Errorf ("argument http_proxy is required but can be an empty value" )
232- }
233- if args .HttpsProxy == unsetProxyArgValue {
254+ case args .HTTPSProxy == unsetProxyArgValue :
234255 return fmt .Errorf ("argument https_proxy is required but can be an empty value" )
235- }
236- if args .NoProxy == unsetProxyArgValue {
256+ case args .NoProxy == unsetProxyArgValue :
237257 return fmt .Errorf ("argument no_proxy is required but can be an empty value" )
238258 }
239- // validation responsibility lives in driver itself, so we do not validate all other args
240259 return nil
241260}
242261
@@ -251,18 +270,17 @@ func extractOutputParameters(execution *driver.Execution, driverType string) []a
251270 Value : fmt .Sprint (execution .ID ),
252271 })
253272 }
254- if execution .IterationCount != nil {
273+ switch {
274+ case execution .IterationCount != nil :
255275 outputs = append (outputs , api.Parameter {
256276 Name : "iteration-count" ,
257- Value : fmt .Sprint (execution .IterationCount ),
277+ Value : fmt .Sprint (* execution .IterationCount ),
278+ })
279+ case driverType == RootDag :
280+ outputs = append (outputs , api.Parameter {
281+ Name : "iteration-count" ,
282+ Value : "0" ,
258283 })
259- } else {
260- if driverType == ROOT_DAG {
261- outputs = append (outputs , api.Parameter {
262- Name : "iteration-count" ,
263- Value : fmt .Sprint (0 ),
264- })
265- }
266284 }
267285 if execution .Cached != nil {
268286 outputs = append (outputs , api.Parameter {
0 commit comments