Skip to content

Commit 4ff5205

Browse files
committed
more tests
Signed-off-by: oliver könig <[email protected]>
1 parent 94139b6 commit 4ff5205

File tree

2 files changed

+110
-15
lines changed

2 files changed

+110
-15
lines changed

nemo_run/core/execution/dgxcloud.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,8 +368,8 @@ def fetch_logs(
368368
) -> Iterable[str]:
369369
token = self.get_auth_token()
370370
if not token:
371-
logger.error("Failed to retrieve auth token for cancellation request.")
372-
return
371+
logger.error("Failed to retrieve auth token for fetch logs request.")
372+
yield ""
373373

374374
response = requests.get(
375375
f"{self.base_url}/workloads", headers=self._default_headers(token=token)
@@ -383,6 +383,7 @@ def fetch_logs(
383383
None,
384384
)
385385
if workload_name is None:
386+
logger.error(f"No workload found with id {job_id}")
386387
yield ""
387388

388389
urls = [
@@ -415,7 +416,15 @@ def fetch_logs(
415416
# Yield chunks as they arrive
416417
while active_urls:
417418
url, item = q.get()
418-
if item is None:
419+
if item is None or self.status(job_id) in [
420+
DGXCloudState.DELETING,
421+
DGXCloudState.STOPPED,
422+
DGXCloudState.STOPPING,
423+
DGXCloudState.DEGRADED,
424+
DGXCloudState.FAILED,
425+
DGXCloudState.COMPLETED,
426+
DGXCloudState.TERMINATING,
427+
]:
419428
active_urls.discard(url)
420429
else:
421430
yield item

test/core/execution/test_dgxcloud.py

Lines changed: 98 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import logging
1617
import os
18+
import queue
1719
import subprocess
1820
import tempfile
1921
from unittest.mock import MagicMock, mock_open, patch
@@ -101,8 +103,67 @@ def test_get_auth_token_failure(self, mock_post):
101103

102104
assert token is None
103105

106+
def test_fetch_no_token(self, caplog):
107+
with (
108+
patch.object(DGXCloudExecutor, "get_auth_token", return_value=None),
109+
caplog.at_level(logging.ERROR),
110+
):
111+
executor = DGXCloudExecutor(
112+
base_url="https://dgxapi.example.com",
113+
kube_apiserver_url="https://127.0.0.1:443",
114+
app_id="test_app_id",
115+
app_secret="test_app_secret",
116+
project_name="test_project",
117+
container_image="nvcr.io/nvidia/test:latest",
118+
pvc_nemo_run_dir="/workspace/nemo_run",
119+
)
120+
121+
logs_iter = executor.fetch_logs("123", stream=True)
122+
assert next(logs_iter) == ""
123+
assert (
124+
caplog.records[-1].message
125+
== "Failed to retrieve auth token for fetch logs request."
126+
)
127+
assert caplog.records[-1].levelname == "ERROR"
128+
caplog.clear()
129+
104130
@patch("nemo_run.core.execution.dgxcloud.requests.get")
105-
def test_fetch_logs(self, mock_requests_get):
131+
def test_fetch_no_workload_with_name(self, mock_requests_get, caplog):
132+
mock_workloads_response = MagicMock(spec=requests.Response)
133+
mock_workloads_response.json.return_value = {
134+
"workloads": [{"name": "hello-world", "id": "123"}]
135+
}
136+
137+
mock_requests_get.side_effect = [mock_workloads_response]
138+
139+
with (
140+
patch.object(DGXCloudExecutor, "get_auth_token", return_value="test_token"),
141+
caplog.at_level(logging.ERROR),
142+
):
143+
executor = DGXCloudExecutor(
144+
base_url="https://dgxapi.example.com",
145+
kube_apiserver_url="https://127.0.0.1:443",
146+
app_id="test_app_id",
147+
app_secret="test_app_secret",
148+
project_name="test_project",
149+
container_image="nvcr.io/nvidia/test:latest",
150+
pvc_nemo_run_dir="/workspace/nemo_run",
151+
)
152+
153+
logs_iter = executor.fetch_logs("this-workload-does-not-exist", stream=True)
154+
assert next(logs_iter) == ""
155+
assert (
156+
caplog.records[-1].message
157+
== "No workload found with id this-workload-does-not-exist"
158+
)
159+
assert caplog.records[-1].levelname == "ERROR"
160+
caplog.clear()
161+
162+
@patch("nemo_run.core.execution.dgxcloud.requests.get")
163+
@patch("nemo_run.core.execution.dgxcloud.time.sleep")
164+
@patch("nemo_run.core.execution.dgxcloud.threading.Thread")
165+
@patch("nemo_run.core.execution.dgxcloud.queue.Queue")
166+
def test_fetch_logs(self, mock_queue, mock_threading_Thread, mock_sleep, mock_requests_get):
106167
# --- 1. Setup Primitives for the *live* test ---
107168
mock_log_response = MagicMock(spec=requests.Response)
108169

@@ -117,12 +178,29 @@ def test_fetch_logs(self, mock_requests_get):
117178
"workloads": [{"name": "hello-world", "id": "123"}]
118179
}
119180

181+
mock_queue_instance = MagicMock()
182+
mock_queue_instance.get.side_effect = [
183+
(
184+
"https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true",
185+
"this is a static log\n",
186+
),
187+
(
188+
"https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true",
189+
None,
190+
),
191+
(
192+
"https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-1/log?container=pytorch&follow=true",
193+
None,
194+
),
195+
]
196+
120197
mock_requests_get.side_effect = [mock_workloads_response, mock_log_response]
121198

122199
# --- 4. Setup Executor (inside the patch) ---
123200
with (
124201
patch.object(DGXCloudExecutor, "get_auth_token", return_value="test_token"),
125202
patch.object(DGXCloudExecutor, "status", return_value=DGXCloudState.RUNNING),
203+
patch("nemo_run.core.execution.dgxcloud.queue.Queue", return_value=mock_queue_instance),
126204
):
127205
executor = DGXCloudExecutor(
128206
base_url="https://dgxapi.example.com",
@@ -132,24 +210,31 @@ def test_fetch_logs(self, mock_requests_get):
132210
project_name="test_project",
133211
container_image="nvcr.io/nvidia/test:latest",
134212
pvc_nemo_run_dir="/workspace/nemo_run",
213+
nodes=2,
135214
)
136215

137216
logs_iter = executor.fetch_logs("123", stream=True)
138217

139218
assert next(logs_iter) == "this is a static log\n"
140-
assert next(logs_iter) == "this is the last static log\n"
141219

142-
mock_requests_get.assert_any_call(
143-
"https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true",
144-
headers={
145-
"Accept": "application/json",
146-
"Content-Type": "application/json",
147-
"Authorization": "Bearer test_token",
148-
},
149-
verify=False,
150-
stream=True,
151-
)
220+
mock_sleep.assert_called_once_with(10)
152221

222+
mock_threading_Thread.assert_any_call(
223+
target=executor._stream_url_sync,
224+
args=(
225+
"https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true",
226+
executor._default_headers(token="test_token"),
227+
mock_queue_instance,
228+
),
229+
)
230+
mock_threading_Thread.assert_any_call(
231+
target=executor._stream_url_sync,
232+
args=(
233+
"https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-1/log?container=pytorch&follow=true",
234+
executor._default_headers(token="test_token"),
235+
mock_queue_instance,
236+
),
237+
)
153238
with pytest.raises(StopIteration):
154239
next(logs_iter)
155240

@@ -1007,3 +1092,4 @@ def test_default_headers_with_token(self):
10071092
assert headers["Content-Type"] == "application/json"
10081093
assert "Authorization" in headers
10091094
assert headers["Authorization"] == "Bearer test_token"
1095+
assert headers["Authorization"] == "Bearer test_token"

0 commit comments

Comments
 (0)