Skip to content

Commit d7b135c

Browse files
Merge pull request #33 from Axiomatic-AI/tools-helper-client
Add tool hosting object parsing
2 parents 93b4f7c + 9d833dd commit d7b135c

File tree

1 file changed

+37
-5
lines changed

1 file changed

+37
-5
lines changed

src/axiomatic/client.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import base64
2+
import dill # type: ignore
3+
import json
24
import requests
35
import os
46
import time
@@ -192,12 +194,12 @@ def tool_exec(self, tool: str, code: str, poll_interval: int = 3, debug: bool =
192194
tool_name = tool.strip()
193195
code_string = code
194196

195-
output = self._ax_client.tools.schedule(
197+
tool_result = self._ax_client.tools.schedule(
196198
tool_name=tool_name,
197199
code=code_string,
198200
)
199-
if output.is_success is True:
200-
job_id = str(output.job_id)
201+
if tool_result.is_success is True:
202+
job_id = str(tool_result.job_id)
201203
result = self._ax_client.tools.status(job_id=job_id)
202204
if debug:
203205
print(f"job_id: {job_id}")
@@ -211,11 +213,41 @@ def tool_exec(self, tool: str, code: str, poll_interval: int = 3, debug: bool =
211213
if debug:
212214
print(f"status: {result.status}")
213215
if result.status == "SUCCEEDED":
214-
return result.output
216+
output = json.loads(result.output or "{}")
217+
if not output['objects']:
218+
return result.output
219+
else:
220+
return {
221+
"messages": output['messages'],
222+
"objects": self._load_objects_from_base64(output['objects'])
223+
}
215224
else:
216225
return result.error_trace
217226
else:
218-
return output.error_trace
227+
return tool_result.error_trace
228+
229+
def load(self, job_id: str, obj_key: str):
230+
result = self._ax_client.tools.status(job_id=job_id)
231+
if result.status == "SUCCEEDED":
232+
output = json.loads(result.output or "{}")
233+
if not output['objects']:
234+
return result.output
235+
else:
236+
return self._load_objects_from_base64(output['objects'])[obj_key]
237+
else:
238+
return result.error_trace
239+
240+
def _load_objects_from_base64(self, encoded_dict):
241+
loaded_objects = {}
242+
for key, encoded_str in encoded_dict.items():
243+
try:
244+
decoded_bytes = base64.b64decode(encoded_str)
245+
loaded_obj = dill.loads(decoded_bytes)
246+
loaded_objects[key] = loaded_obj
247+
except Exception as e:
248+
print(f"Error loading object for key '{key}': {e}")
249+
loaded_objects[key] = None
250+
return loaded_objects
219251

220252

221253
class AsyncAxiomatic(AsyncBaseClient): ...

0 commit comments

Comments
 (0)