Skip to content

Commit c841d7a

Browse files
author
arpechenin
committed
- merge tests
- fix formatting Signed-off-by: arpechenin <[email protected]>
1 parent e9955cf commit c841d7a

File tree

145 files changed

+7556
-15621
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

145 files changed

+7556
-15621
lines changed

backend/src/driver/api/request.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
// Copyright 2025 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// Package api provides HTTP DTOs used by the driver server.
116
package api
217

318
type DriverPluginArgs struct {
@@ -6,8 +21,8 @@ type DriverPluginArgs struct {
621
Container string `json:"container,omitempty"`
722
DagExecutionID string `json:"dag_execution_id"`
823
IterationIndex string `json:"iteration_index"`
9-
HttpProxy string `json:"http_proxy"`
10-
HttpsProxy string `json:"https_proxy"`
24+
HTTPProxy string `json:"http_proxy"`
25+
HTTPSProxy string `json:"https_proxy"`
1126
NoProxy string `json:"no_proxy"`
1227
KubernetesConfig string `json:"kubernetes_config,omitempty"`
1328
RuntimeConfig string `json:"runtime_config,omitempty"`
@@ -20,7 +35,7 @@ type DriverPluginArgs struct {
2035
Task string `json:"task"`
2136
Type string `json:"type"`
2237
CacheDisabledFlag bool `json:"cache_disabled"`
23-
ExecutionIdPath string `json:"execution_id_path"`
38+
ExecutionIDPath string `json:"execution_id_path"`
2439
IterationCountPath string `json:"iteration_count_path"`
2540
ConditionPath string `json:"condition_path"`
2641
PodSpecPathPath string `json:"pod_spec_patch_path"`

backend/src/driver/api/response.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright 2025 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
package api
216

317
type DriverResponse struct {

backend/src/driver/execution_paths.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright 2025 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
package main
216

317
type ExecutionPaths struct {

backend/src/driver/main.go

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14+
1415
package main
1516

1617
import (
@@ -37,7 +38,7 @@ import (
3738

3839
const (
3940
unsetProxyArgValue = "unset"
40-
ROOT_DAG = "ROOT_DAG"
41+
RootDag = "ROOT_DAG"
4142
DAG = "DAG"
4243
CONTAINER = "CONTAINER"
4344
)
@@ -76,13 +77,13 @@ func init() {
7677
flag.Set("stderrthreshold", "WARNING")
7778
}
7879

79-
func parseExecConfigJson(k8sExecConfigJson *string) (*kubernetesplatform.KubernetesExecutorConfig, error) {
80+
func parseExecConfigJSON(k8sExecConfigJSON *string) (*kubernetesplatform.KubernetesExecutorConfig, error) {
8081
var k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig
81-
if *k8sExecConfigJson != "" {
82-
glog.Infof("input kubernetesConfig:%s\n", prettyPrint(*k8sExecConfigJson))
82+
if *k8sExecConfigJSON != "" {
83+
glog.Infof("input kubernetesConfig:%s\n", prettyPrint(*k8sExecConfigJSON))
8384
k8sExecCfg = &kubernetesplatform.KubernetesExecutorConfig{}
84-
if err := util.UnmarshalString(*k8sExecConfigJson, k8sExecCfg); err != nil {
85-
return nil, fmt.Errorf("failed to unmarshal Kubernetes config, error: %w\nKubernetesConfig: %v", err, k8sExecConfigJson)
85+
if err := util.UnmarshalString(*k8sExecConfigJSON, k8sExecCfg); err != nil {
86+
return nil, fmt.Errorf("failed to unmarshal Kubernetes config, error: %w\nKubernetesConfig: %v", err, k8sExecConfigJSON)
8687
}
8788
}
8889
return k8sExecCfg, nil
@@ -102,7 +103,7 @@ func handleExecution(execution *driver.Execution, driverType string, executionPa
102103
return fmt.Errorf("failed to write iteration count to file: %w", err)
103104
}
104105
} else {
105-
if driverType == ROOT_DAG {
106+
if driverType == RootDag {
106107
if err := writeFile(executionPaths.IterationCount, []byte("0")); err != nil {
107108
return fmt.Errorf("failed to write iteration count to file: %w", err)
108109
}
@@ -119,7 +120,7 @@ func handleExecution(execution *driver.Execution, driverType string, executionPa
119120
}
120121
} else {
121122
// nil is a valid value for Condition
122-
if driverType == ROOT_DAG || driverType == CONTAINER {
123+
if driverType == RootDag || driverType == CONTAINER {
123124
if err := writeFile(executionPaths.Condition, []byte("nil")); err != nil {
124125
return fmt.Errorf("failed to write condition to file: %w", err)
125126
}

backend/src/driver/main_test.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright 2025 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
package main
216

317
import (
@@ -45,7 +59,7 @@ func TestSpecParsing(t *testing.T) {
4559

4660
for _, tc := range tt {
4761
t.Logf("Running test case: %s", tc.name)
48-
cfg, err := parseExecConfigJson(tc.input)
62+
cfg, err := parseExecConfigJSON(tc.input)
4963
assert.Equal(t, tc.wantErr, err != nil)
5064
assert.True(t, proto.Equal(tc.expected, cfg))
5165
}
@@ -77,7 +91,7 @@ func Test_handleExecutionRootDAG(t *testing.T) {
7791
Condition: "condition.txt",
7892
}
7993

80-
err := handleExecution(execution, ROOT_DAG, executionPaths)
94+
err := handleExecution(execution, RootDag, executionPaths)
8195

8296
if err != nil {
8397
t.Errorf("Unexpected error: %v", err)

backend/src/driver/rpc_handler.go

Lines changed: 56 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
// Copyright 2025 The Kubeflow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
115
package main
216

317
import (
@@ -94,13 +108,14 @@ func parseDriverRequestArgs(r *http.Request) (*api.DriverPluginArgs, error) {
94108
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
95109
return nil, fmt.Errorf("failed to parse driver request body: %v", err)
96110
}
97-
if body.Template == nil {
111+
switch {
112+
case body.Template == nil:
98113
return nil, fmt.Errorf("driver request body.Template is empty")
99-
} else if body.Template.Plugin == nil {
114+
case body.Template.Plugin == nil:
100115
return nil, fmt.Errorf("driver request body.Template.Plugin is empty")
101-
} else if body.Template.Plugin.DriverPlugin == nil {
116+
case body.Template.Plugin.DriverPlugin == nil:
102117
return nil, fmt.Errorf("driver request body.Template.Plugin.DriverPlugin is empty")
103-
} else if body.Template.Plugin.DriverPlugin.Args == nil {
118+
case body.Template.Plugin.DriverPlugin.Args == nil:
104119
return nil, fmt.Errorf("driver request body.Template.Plugin.Args is empty")
105120
}
106121
args := body.Template.Plugin.DriverPlugin.Args
@@ -118,7 +133,7 @@ func drive(args api.DriverPluginArgs) (execution *driver.Execution, err error) {
118133
}()
119134
ctx := context.Background()
120135

121-
proxy.InitializeConfig(args.HttpProxy, args.HttpsProxy, args.NoProxy)
136+
proxy.InitializeConfig(args.HTTPProxy, args.HTTPSProxy, args.NoProxy)
122137

123138
glog.Infof("input ComponentSpec:%s\n", prettyPrint(args.Component))
124139
componentSpec := &pipelinespec.ComponentSpec{}
@@ -149,7 +164,7 @@ func drive(args api.DriverPluginArgs) (execution *driver.Execution, err error) {
149164
return nil, fmt.Errorf("failed to unmarshal runtime config, error: %w\nruntimeConfig: %v", err, args.RuntimeConfig)
150165
}
151166
}
152-
k8sExecCfg, err := parseExecConfigJson(&args.KubernetesConfig)
167+
k8sExecCfg, err := parseExecConfigJSON(&args.KubernetesConfig)
153168
if err != nil {
154169
return nil, err
155170
}
@@ -160,7 +175,9 @@ func drive(args api.DriverPluginArgs) (execution *driver.Execution, err error) {
160175
var tlsCfg *tls.Config
161176
if args.MetadataTLSEnabled {
162177
tlsCfg, err = util.GetTLSConfig(args.CACertPath)
163-
return nil, fmt.Errorf("unable to drive driver: failed to load TLS configuration: %v", err)
178+
if err != nil {
179+
return nil, fmt.Errorf("unable to drive driver: failed to load TLS configuration: %v", err)
180+
}
164181
}
165182
client, err := newMlmdClient(tlsCfg)
166183
if err != nil {
@@ -180,23 +197,28 @@ func drive(args api.DriverPluginArgs) (execution *driver.Execution, err error) {
180197
return nil, fmt.Errorf("failed to parse iteration index, error: %w", err)
181198
}
182199
options := driver.Options{
183-
PipelineName: args.PipelineName,
184-
RunID: args.RunID,
185-
RunName: args.RunName,
186-
RunDisplayName: args.RunDisplayName,
187-
Namespace: namespace,
188-
Component: componentSpec,
189-
Task: taskSpec,
190-
DAGExecutionID: dagExecutionID,
191-
IterationIndex: iterationIndex,
192-
PublishLogs: args.PublishLogs,
193-
CacheDisabled: args.CacheDisabledFlag,
194-
DriverType: args.Type,
195-
TaskName: args.TaskName,
200+
PipelineName: args.PipelineName,
201+
RunID: args.RunID,
202+
RunName: args.RunName,
203+
RunDisplayName: args.RunDisplayName,
204+
Namespace: namespace,
205+
Component: componentSpec,
206+
Task: taskSpec,
207+
DAGExecutionID: dagExecutionID,
208+
IterationIndex: iterationIndex,
209+
PublishLogs: args.PublishLogs,
210+
CacheDisabled: args.CacheDisabledFlag,
211+
DriverType: args.Type,
212+
TaskName: args.TaskName,
213+
MLPipelineTLSEnabled: args.MlPipelineTLSEnabled,
214+
MLMDServerAddress: *mlmdServerAddress,
215+
MLMDServerPort: *mlmdServerPort,
216+
CaCertPath: args.CACertPath,
196217
}
218+
197219
var driverErr error
198220
switch args.Type {
199-
case ROOT_DAG:
221+
case RootDag:
200222
options.RuntimeConfig = runtimeConfig
201223
execution, driverErr = driver.RootDAG(ctx, options, client)
202224
case DAG:
@@ -224,19 +246,16 @@ func drive(args api.DriverPluginArgs) (execution *driver.Execution, err error) {
224246
}
225247

226248
func validate(args api.DriverPluginArgs) error {
227-
if args.Type == "" {
249+
switch {
250+
case args.Type == "":
228251
return fmt.Errorf("argument type must be specified")
229-
}
230-
if args.HttpProxy == unsetProxyArgValue {
252+
case args.HTTPProxy == unsetProxyArgValue:
231253
return fmt.Errorf("argument http_proxy is required but can be an empty value")
232-
}
233-
if args.HttpsProxy == unsetProxyArgValue {
254+
case args.HTTPSProxy == unsetProxyArgValue:
234255
return fmt.Errorf("argument https_proxy is required but can be an empty value")
235-
}
236-
if args.NoProxy == unsetProxyArgValue {
256+
case args.NoProxy == unsetProxyArgValue:
237257
return fmt.Errorf("argument no_proxy is required but can be an empty value")
238258
}
239-
// validation responsibility lives in driver itself, so we do not validate all other args
240259
return nil
241260
}
242261

@@ -251,18 +270,17 @@ func extractOutputParameters(execution *driver.Execution, driverType string) []a
251270
Value: fmt.Sprint(execution.ID),
252271
})
253272
}
254-
if execution.IterationCount != nil {
273+
switch {
274+
case execution.IterationCount != nil:
255275
outputs = append(outputs, api.Parameter{
256276
Name: "iteration-count",
257-
Value: fmt.Sprint(execution.IterationCount),
277+
Value: fmt.Sprint(*execution.IterationCount),
278+
})
279+
case driverType == RootDag:
280+
outputs = append(outputs, api.Parameter{
281+
Name: "iteration-count",
282+
Value: "0",
258283
})
259-
} else {
260-
if driverType == ROOT_DAG {
261-
outputs = append(outputs, api.Parameter{
262-
Name: "iteration-count",
263-
Value: fmt.Sprint(0),
264-
})
265-
}
266284
}
267285
if execution.Cached != nil {
268286
outputs = append(outputs, api.Parameter{

0 commit comments

Comments
 (0)