|
6 | 6 | from threading import RLock
|
7 | 7 | from uuid import uuid4
|
8 | 8 |
|
| 9 | +from azure.common.credentials import get_cli_profile |
9 | 10 | from azure.core.exceptions import ResourceNotFoundError
|
10 | 11 | from azure.identity import DefaultAzureCredential
|
11 | 12 | from azure.mgmt.compute import ComputeManagementClient
|
@@ -65,7 +66,17 @@ class AzureNodeProvider(NodeProvider):
|
65 | 66 |
|
66 | 67 | def __init__(self, provider_config, cluster_name):
|
67 | 68 | NodeProvider.__init__(self, provider_config, cluster_name)
|
68 |
| - subscription_id = provider_config["subscription_id"] |
| 69 | + subscription_id = provider_config.get("subscription_id") |
| 70 | + if subscription_id is None: |
| 71 | + # Get subscription from logged in azure profile |
| 72 | + # if it isn't provided in the provider_config |
| 73 | + # so operations like `get-head-ip` will work |
| 74 | + subscription_id = get_cli_profile().get_subscription_id() |
| 75 | + provider_config["subscription_id"] = subscription_id |
| 76 | + logger.info( |
| 77 | + "subscription_id not found in provider config, falling back " |
| 78 | + f"to subscription_id from the logged in azure profile: {subscription_id}" |
| 79 | + ) |
69 | 80 | self.cache_stopped_nodes = provider_config.get("cache_stopped_nodes", True)
|
70 | 81 | credential = DefaultAzureCredential(exclude_shared_token_cache_credential=True)
|
71 | 82 | self.compute_client = ComputeManagementClient(credential, subscription_id)
|
|
0 commit comments