1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16+ import logging
1617import os
18+ import queue
1719import subprocess
1820import tempfile
1921from unittest .mock import MagicMock , mock_open , patch
@@ -101,8 +103,67 @@ def test_get_auth_token_failure(self, mock_post):
101103
102104 assert token is None
103105
def test_fetch_no_token(self, caplog):
    """Check the log stream and error record when no auth token is returned.

    ``get_auth_token`` is patched to return ``None``; the test then expects
    the first item produced by ``fetch_logs`` to be an empty string and the
    most recent captured log record to be the ERROR-level token failure
    message.
    """
    # Simulate a failed token lookup for the duration of the test body.
    token_patch = patch.object(DGXCloudExecutor, "get_auth_token", return_value=None)

    with token_patch, caplog.at_level(logging.ERROR):
        executor = DGXCloudExecutor(
            base_url="https://dgxapi.example.com",
            kube_apiserver_url="https://127.0.0.1:443",
            app_id="test_app_id",
            app_secret="test_app_secret",
            project_name="test_project",
            container_image="nvcr.io/nvidia/test:latest",
            pvc_nemo_run_dir="/workspace/nemo_run",
        )

        log_stream = executor.fetch_logs("123", stream=True)
        first_chunk = next(log_stream)
        assert first_chunk == ""

        # Only the most recent record is inspected.
        last_record = caplog.records[-1]
        assert last_record.message == "Failed to retrieve auth token for fetch logs request."
        assert last_record.levelname == "ERROR"
        caplog.clear()
129+
@patch("nemo_run.core.execution.dgxcloud.requests.get")
def test_fetch_no_workload_with_name(self, mock_requests_get, caplog):
    """Check the log stream and error record for an unknown workload id.

    The mocked workloads response lists a single workload with id ``"123"``,
    while ``fetch_logs`` is asked for ``"this-workload-does-not-exist"``. The
    test expects an empty string as the first streamed item and an
    ERROR-level "No workload found" record as the most recent log entry.
    """
    # The only HTTP response made available to the code under test.
    workloads_reply = MagicMock(spec=requests.Response)
    workloads_reply.json.return_value = {
        "workloads": [{"name": "hello-world", "id": "123"}],
    }
    mock_requests_get.side_effect = [workloads_reply]

    token_patch = patch.object(DGXCloudExecutor, "get_auth_token", return_value="test_token")

    with token_patch, caplog.at_level(logging.ERROR):
        executor = DGXCloudExecutor(
            base_url="https://dgxapi.example.com",
            kube_apiserver_url="https://127.0.0.1:443",
            app_id="test_app_id",
            app_secret="test_app_secret",
            project_name="test_project",
            container_image="nvcr.io/nvidia/test:latest",
            pvc_nemo_run_dir="/workspace/nemo_run",
        )

        log_stream = executor.fetch_logs("this-workload-does-not-exist", stream=True)
        first_chunk = next(log_stream)
        assert first_chunk == ""

        # Only the most recent record is inspected.
        last_record = caplog.records[-1]
        assert last_record.message == "No workload found with id this-workload-does-not-exist"
        assert last_record.levelname == "ERROR"
        caplog.clear()
161+
162+ @patch ("nemo_run.core.execution.dgxcloud.requests.get" )
163+ @patch ("nemo_run.core.execution.dgxcloud.time.sleep" )
164+ @patch ("nemo_run.core.execution.dgxcloud.threading.Thread" )
165+ @patch ("nemo_run.core.execution.dgxcloud.queue.Queue" )
166+ def test_fetch_logs (self , mock_queue , mock_threading_Thread , mock_sleep , mock_requests_get ):
106167 # --- 1. Setup Primitives for the *live* test ---
107168 mock_log_response = MagicMock (spec = requests .Response )
108169
@@ -117,12 +178,29 @@ def test_fetch_logs(self, mock_requests_get):
117178 "workloads" : [{"name" : "hello-world" , "id" : "123" }]
118179 }
119180
181+ mock_queue_instance = MagicMock ()
182+ mock_queue_instance .get .side_effect = [
183+ (
184+ "https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true" ,
185+ "this is a static log\n " ,
186+ ),
187+ (
188+ "https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true" ,
189+ None ,
190+ ),
191+ (
192+ "https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-1/log?container=pytorch&follow=true" ,
193+ None ,
194+ ),
195+ ]
196+
120197 mock_requests_get .side_effect = [mock_workloads_response , mock_log_response ]
121198
122199 # --- 4. Setup Executor (inside the patch) ---
123200 with (
124201 patch .object (DGXCloudExecutor , "get_auth_token" , return_value = "test_token" ),
125202 patch .object (DGXCloudExecutor , "status" , return_value = DGXCloudState .RUNNING ),
203+ patch ("nemo_run.core.execution.dgxcloud.queue.Queue" , return_value = mock_queue_instance ),
126204 ):
127205 executor = DGXCloudExecutor (
128206 base_url = "https://dgxapi.example.com" ,
@@ -132,24 +210,31 @@ def test_fetch_logs(self, mock_requests_get):
132210 project_name = "test_project" ,
133211 container_image = "nvcr.io/nvidia/test:latest" ,
134212 pvc_nemo_run_dir = "/workspace/nemo_run" ,
213+ nodes = 2 ,
135214 )
136215
137216 logs_iter = executor .fetch_logs ("123" , stream = True )
138217
139218 assert next (logs_iter ) == "this is a static log\n "
140- assert next (logs_iter ) == "this is the last static log\n "
141219
142- mock_requests_get .assert_any_call (
143- "https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true" ,
144- headers = {
145- "Accept" : "application/json" ,
146- "Content-Type" : "application/json" ,
147- "Authorization" : "Bearer test_token" ,
148- },
149- verify = False ,
150- stream = True ,
151- )
220+ mock_sleep .assert_called_once_with (10 )
152221
222+ mock_threading_Thread .assert_any_call (
223+ target = executor ._stream_url_sync ,
224+ args = (
225+ "https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true" ,
226+ executor ._default_headers (token = "test_token" ),
227+ mock_queue_instance ,
228+ ),
229+ )
230+ mock_threading_Thread .assert_any_call (
231+ target = executor ._stream_url_sync ,
232+ args = (
233+ "https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-1/log?container=pytorch&follow=true" ,
234+ executor ._default_headers (token = "test_token" ),
235+ mock_queue_instance ,
236+ ),
237+ )
153238 with pytest .raises (StopIteration ):
154239 next (logs_iter )
155240
@@ -1007,3 +1092,4 @@ def test_default_headers_with_token(self):
10071092 assert headers ["Content-Type" ] == "application/json"
10081093 assert "Authorization" in headers
10091094 assert headers ["Authorization" ] == "Bearer test_token"
1095+ assert headers ["Authorization" ] == "Bearer test_token"
0 commit comments