Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 24 additions & 6 deletions deployment/promotion/promote.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import os
import datetime

from upload_to_s3 import upload_to_s3, get_base_path_to_artifact, update_artifact_path

env_vars = ['VERTA_SOURCE_MODEL_VERSION_ID', 'VERTA_SOURCE_HOST', 'VERTA_SOURCE_EMAIL',
'VERTA_SOURCE_DEV_KEY',
'VERTA_SOURCE_WORKSPACE_0', 'VERTA_DEST_HOST', 'VERTA_DEST_EMAIL',
Expand Down Expand Up @@ -66,6 +68,7 @@
'grpc-metadata-developer_key': os.environ.get('VERTA_SOURCE_DEV_KEY')}
workspaces_source = requests.get(host, headers=headers_dict, proxies=proxies)

source_workspace = None
for item in workspaces_source.json()['workspace']:
if 'id' in item.keys() and item['id'] == source_workspace_id:
if 'org_name' in item.keys():
Expand All @@ -81,6 +84,8 @@
os.environ['VERTA_DEST_WORKSPACE'] = item['org_name']
elif 'username' in item.keys() and item['username'] == source_workspace:
os.environ['VERTA_DEST_WORKSPACE'] = item['username']
else:
source_workspace = os.environ.get('VERTA_DEST_WORKSPACE')

for param_name in env_vars:
param = os.environ.get(param_name)
Expand All @@ -101,7 +106,7 @@

config = {
'source': {
'model_version_id': atoi(params['VERTA_SOURCE_MODEL_VERSION_ID'][2:-2]),
'model_version_id': atoi(params['VERTA_SOURCE_MODEL_VERSION_ID']), # [2:-2]),
'host': params['VERTA_SOURCE_HOST'],
'email': params['VERTA_SOURCE_EMAIL'],
'devkey': params['VERTA_SOURCE_DEV_KEY'],
Expand Down Expand Up @@ -233,8 +238,11 @@ def download_artifact(auth, model_version_id, artifact):
key = artifact['key']
url = signed_artifact_url(auth, model_version_id, artifact)
print("Downloading artifact '%s'" % key)
curl_cmd = "curl --cacert %s -o %s %s '%s'" % (
os.environ['REQUESTS_CA_BUNDLE'], key, params['VERTA_CURL_OPTS'], url)
if 'REQUESTS_CA_BUNDLE' not in os.environ:
curl_cmd = "curl -o %s %s '%s'" % (key, params['VERTA_CURL_OPTS'], url)
else:
curl_cmd = "curl --cacert %s -o %s %s '%s'" % (
os.environ['REQUESTS_CA_BUNDLE'], key, params['VERTA_CURL_OPTS'], url)
os.system(curl_cmd)


Expand All @@ -250,7 +258,7 @@ def download_artifacts(auth, model_version_id, artifacts, model_artifact):
copy_fields(['artifact_type', 'key'], artifact, artifact_request)
download_artifact(auth, model_version_id, artifact_request)
downloaded_artifacts.append(
{'key': artifact['key'], 'artifact_type': artifact['artifact_type']})
{'key': artifact['key'], 'artifact_type': artifact['artifact_type'], 'path': artifact['path'], 'filename_extension': artifact['filename_extension']})

model_artifact_request = {
'method': 'GET',
Expand All @@ -264,6 +272,10 @@ def download_artifacts(auth, model_version_id, artifacts, model_artifact):

def upload_artifact(auth, model_version_id, artifact):
key = artifact['key']
print("Uploading artifact '%s' to s3" % key)

upload_to_s3(artifact, "")

print("Uploading artifact '%s'" % key)
print(artifact)

Expand All @@ -284,8 +296,8 @@ def upload_artifact(auth, model_version_id, artifact):

if not put_response.ok:
raise Exception("Failed to put artifact (%d %s). Key: %s\tURL: %s\tText: %s" % (
put_response.status_code,
put_response.reason, key, put_url, put_response.text))
put_response.status_code,
put_response.reason, key, put_url, put_response.text))

check_url = signed_artifact_url(auth, model_version_id,
{'method': 'GET', 'model_version_id': model_version_id,
Expand All @@ -304,6 +316,12 @@ def upload_artifacts(auth, model_version_id, artifacts):
print("Uploading %d artifacts" % len(artifacts))
uploaded_artifacts = {}

# Update paths
base_path = get_base_path_to_artifact(artifacts)
for artifact in artifacts:
update_artifact_path(artifact, base_path)
print(f"Updated artifact path to {artifact['path']}")

for artifact in artifacts:
uploaded_artifacts[artifact["key"]] = upload_artifact(auth, model_version_id, artifact)
return uploaded_artifacts
Expand Down
53 changes: 53 additions & 0 deletions deployment/promotion/upload_to_s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import boto3
import os

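# AWS credentials are not set in this file; boto3 resolves them from its default
# credential chain. The commented-out lines below only show the relevant settings.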
# aws_access_key = os.environ['AWS_ACCESS_KEY_ID']
# aws_secret_key = os.environ['AWS_SECRET_ACCESS_KEY']
# aws_region = 'us-east-1' # Change to your desired region


def get_base_path_to_artifact(artifacts, source_env="hmpreprod", dest_env="hm"):
    """Derive the shared destination directory from the first artifact's path.

    The first artifact's ``path`` is split on ``/``; its first and last
    segments are dropped, the remainder is re-joined with ``/``, and the first
    occurrence of ``source_env`` in the result is replaced with ``dest_env``.

    Args:
        artifacts: Sequence of artifact dicts. Only ``artifacts[0]['path']``
            is read.
        source_env: Substring to rewrite in the path. Defaults to
            ``"hmpreprod"``.
        dest_env: Replacement for ``source_env``. Defaults to ``"hm"``.

    Returns:
        The rewritten directory path, or ``None`` when ``artifacts`` is empty.

    Raises:
        KeyError: If the first artifact has no ``'path'`` entry.
    """
    if not artifacts:
        # This shouldn't happen often, if ever.
        print("No artifacts found; cannot construct path.")
        return None

    segments = artifacts[0]['path'].split('/')
    base_path = '/'.join(segments[1:-1])

    # HACK: this is a plain substring replacement of the first match only, so
    # it also fires when `source_env` is merely part of a longer segment.
    # It was flagged in review as hacky and not confirmed to work.
    return base_path.replace(source_env, dest_env, 1)


def get_key_with_extension(artifact):
    """Return the artifact key, appending its filename extension if needed.

    A key that already contains a ``.`` is treated as having an extension and
    is returned unchanged. Otherwise ``artifact['filename_extension']`` is
    appended after a single ``.``. A leading ``.`` on the extension is
    dropped so the result never contains ``..``, and a missing or empty
    extension leaves the key untouched instead of producing a trailing dot.

    Args:
        artifact: Artifact dict with a ``'key'`` entry and, optionally, a
            ``'filename_extension'`` entry.

    Returns:
        The key, with an extension appended when one was needed and available.

    Raises:
        KeyError: If ``artifact`` has no ``'key'`` entry.
    """
    key = artifact['key']
    if '.' in key:
        return key

    extension = str(artifact.get('filename_extension') or '').lstrip('.')
    if not extension:
        # Nothing to append; avoid emitting a key such as "model.".
        return key
    return f"{key}.{extension}"


def update_artifact_path(artifact, base_path):
    """Overwrite ``artifact['path']`` in place with ``base_path`` as text.

    Args:
        artifact: Artifact dict to mutate.
        base_path: New path value; it is stored in its string form.
    """
    new_path = str(base_path)
    artifact['path'] = new_path


def upload_to_s3(artifact, s3_url, bucket_name='vertaai-user-data-dev-us-east-1',
                 key_prefix='testing-hm-preprod'):
    """Upload a local artifact file to S3.

    The local file named ``artifact['key']`` is uploaded to the object key
    ``<key_prefix>/<artifact['path']>/<key with extension>`` in
    ``bucket_name``. The S3 client comes from a default ``boto3.Session()``.

    Args:
        artifact: Artifact dict providing ``'key'`` and ``'path'``, plus
            whatever ``get_key_with_extension`` reads to build the file name.
        s3_url: Not used by this function; kept so existing callers that pass
            it positionally keep working.
        bucket_name: Destination bucket. Defaults to the bucket that was
            previously hard-coded here.
        key_prefix: Leading component of the object key. Defaults to the
            prefix that was previously hard-coded here.
    """
    # Connect to s3
    session = boto3.Session()
    s3 = session.client('s3')

    print(f"about to upload artifact {artifact}")

    # The local file name is the bare key; only the S3 object name gets the
    # extension appended.
    og_key = artifact['key']
    key_with_extension = get_key_with_extension(artifact)
    path = artifact['path']
    s3_key = f"{key_prefix}/{path}/{key_with_extension}"

    # Upload the file to S3
    s3.upload_file(og_key, bucket_name, s3_key)
    print(f"File '{og_key}' uploaded to '{s3_key}' in bucket '{bucket_name}'.")