Skip to content
27 changes: 20 additions & 7 deletions cog_safe_push/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def cog_safe_push(
raise

tasks = []
prediction_index = 1

if model_has_versions:
log.info("Checking schema backwards compatibility")
Expand All @@ -388,8 +389,10 @@ def cog_safe_push(
fuzz_fixed_inputs=fuzz_fixed_inputs,
fuzz_disabled_inputs=fuzz_disabled_inputs,
fuzz_prompt=fuzz_prompt,
prediction_index=prediction_index,
)
)
prediction_index += 1

if test_cases:
for inputs, checker in test_cases:
Expand All @@ -399,8 +402,10 @@ def cog_safe_push(
inputs=inputs,
checker=checker,
predict_timeout=predict_timeout,
prediction_index=prediction_index,
)
)
prediction_index += 1

if fuzz_iterations > 0:
fuzz_inputs_queue = Queue(maxsize=fuzz_iterations)
Expand All @@ -420,8 +425,10 @@ def cog_safe_push(
context=task_context,
inputs_queue=fuzz_inputs_queue,
predict_timeout=predict_timeout,
prediction_index=prediction_index,
)
)
prediction_index += 1

asyncio.run(run_tasks(tasks, parallel=parallel))

Expand All @@ -443,14 +450,18 @@ async def run_tasks(tasks: list[Task], parallel: int) -> None:
log.info(f"Running tasks with parallelism {parallel}")

semaphore = asyncio.Semaphore(parallel)
errors: list[Exception] = []
errors: list[tuple[Exception, int | None]] = []

async def run_with_semaphore(task: Task) -> None:
    """Run a single task while holding the shared semaphore.

    Any exception raised by ``task.run()`` is caught here rather than
    propagated: it is recorded once in the enclosing ``errors`` list as an
    ``(exception, prediction_index)`` pair and logged immediately, so the
    remaining tasks keep running.

    Args:
        task: The task to execute. Its ``prediction_index`` attribute, when
            present, is used to tag the recorded error and the log line.
    """
    async with semaphore:
        try:
            await task.run()
        except Exception as e:
            # Looked up with getattr so tasks without a prediction index
            # are still reported (with no "[n] " prefix).
            prediction_index = getattr(task, "prediction_index", None)
            # Record exactly one entry per failure, always as a pair: the
            # summary loop unpacks ``(error, prediction_index)`` and would
            # break (and the error count would be inflated) if the bare
            # exception were appended as well.
            errors.append((e, prediction_index))
            prefix = "" if prediction_index is None else f"[{prediction_index}] "
            log.error(f"{prefix}{e}")

# Create task coroutines and run them concurrently
task_coroutines = [run_with_semaphore(task) for task in tasks]
Expand All @@ -459,11 +470,13 @@ async def run_with_semaphore(task: Task) -> None:
await asyncio.gather(*task_coroutines, return_exceptions=True)

if errors:
# If there are multiple errors, we'll raise the first one
# but log all of them
for error in errors[1:]:
log.error(f"Additional error occurred: {error}")
raise errors[0]
# Display all errors with their prediction indices
log.error(f"💥 Tests finished with {len(errors)} error(s):")
for error, prediction_index in errors:
prefix = "" if prediction_index is None else f"[{prediction_index}] "
log.error(f"* {prefix}{error}")

raise TaskExecutionError(f"Encountered {len(errors)} task error(s).", errors)


def parse_inputs(inputs_list: list[str]) -> dict[str, Any]:
Expand Down
10 changes: 6 additions & 4 deletions cog_safe_push/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,11 @@ async def predict(
train_destination: Model | None,
inputs: dict,
timeout_seconds: float,
prediction_index: int | None = None,
) -> tuple[Any | None, str | None]:
prefix = f"[{prediction_index}] " if prediction_index is not None else ""
log.vv(
f"Running {'training' if train else 'prediction'} with inputs:\n{json.dumps(inputs, indent=2)}"
f"{prefix}Running {'training' if train else 'prediction'} with inputs:\n{json.dumps(inputs, indent=2)}"
)

start_time = time.time()
Expand Down Expand Up @@ -261,7 +263,7 @@ async def predict(
else:
raise

log.v(f"Prediction URL: https://replicate.com/p/{prediction.id}")
log.v(f"{prefix}Prediction URL: https://replicate.com/p/{prediction.id}")

while prediction.status not in ["succeeded", "failed", "canceled"]:
await asyncio.sleep(0.5)
Expand All @@ -272,13 +274,13 @@ async def predict(
duration = time.time() - start_time

if prediction.status == "failed":
log.v(f"Got error: {prediction.error} ({duration:.2f} sec)")
log.v(f"{prefix}Got error: {prediction.error} ({duration:.2f} sec)")
return None, prediction.error

output = prediction.output
if _has_output_iterator_array_type(version):
output = "".join(cast("list[str]", output))

log.v(f"Got output: {truncate(output)} ({duration:.2f} sec)")
log.v(f"{prefix}Got output: {truncate(output)} ({duration:.2f} sec)")

return output, None
36 changes: 26 additions & 10 deletions cog_safe_push/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class CheckOutputsMatch(Task):
fuzz_fixed_inputs: dict[str, Any]
fuzz_disabled_inputs: list[str]
fuzz_prompt: str | None
prediction_index: int | None = None

async def run(self) -> None:
if self.first_test_case_inputs is not None:
Expand All @@ -50,37 +51,42 @@ async def run(self) -> None:
fuzz_prompt=self.fuzz_prompt,
)

prefix = (
f"[{self.prediction_index}] " if self.prediction_index is not None else ""
)
log.v(
f"Checking outputs match between existing version and test version, with inputs: {inputs}"
f"{prefix}Checking outputs match between existing version and test version, with inputs: {inputs}"
)
test_output, test_error = await predict(
model=self.context.test_model,
train=self.context.is_train(),
train_destination=self.context.train_destination,
inputs=inputs,
timeout_seconds=self.timeout_seconds,
prediction_index=self.prediction_index,
)
output, error = await predict(
model=self.context.model,
train=self.context.is_train(),
train_destination=self.context.train_destination,
inputs=inputs,
timeout_seconds=self.timeout_seconds,
prediction_index=self.prediction_index,
)

if test_error is not None:
raise OutputsDontMatchError(
f"Existing version raised an error: {test_error}"
f"{prefix}Existing version raised an error: {test_error}"
)
if error is not None:
raise OutputsDontMatchError(f"New version raised an error: {error}")
raise OutputsDontMatchError(f"{prefix}New version raised an error: {error}")

matches, match_error = await outputs_match(
test_output, output, is_deterministic
)
if not matches:
raise OutputsDontMatchError(
f"Outputs don't match:\n\ntest output:\n{test_output}\n\nmodel output:\n{output}\n\n{match_error}"
f"{prefix}Outputs don't match:\n\ntest output:\n{test_output}\n\nmodel output:\n{output}\n\n{match_error}"
)


Expand All @@ -90,15 +96,20 @@ class RunTestCase(Task):
inputs: dict[str, Any]
checker: OutputChecker
predict_timeout: int
prediction_index: int | None = None

async def run(self) -> None:
log.v(f"Running test case with inputs: {self.inputs}")
prefix = (
f"[{self.prediction_index}] " if self.prediction_index is not None else ""
)
log.v(f"{prefix}Running test case with inputs: {self.inputs}")
output, error = await predict(
model=self.context.test_model,
train=self.context.is_train(),
train_destination=self.context.train_destination,
inputs=self.inputs,
timeout_seconds=self.predict_timeout,
prediction_index=self.prediction_index,
)

await self.checker(output, error)
Expand Down Expand Up @@ -138,25 +149,30 @@ class FuzzModel(Task):
context: TaskContext
inputs_queue: Queue[dict[str, Any]]
predict_timeout: int
prediction_index: int | None = None

async def run(self) -> None:
inputs = await asyncio.wait_for(self.inputs_queue.get(), timeout=60)

log.v(f"Fuzzing with inputs: {inputs}")
prefix = (
f"[{self.prediction_index}] " if self.prediction_index is not None else ""
)
log.v(f"{prefix}Fuzzing with inputs: {inputs}")
try:
output, error = await predict(
model=self.context.test_model,
train=self.context.is_train(),
train_destination=self.context.train_destination,
inputs=inputs,
timeout_seconds=self.predict_timeout,
prediction_index=self.prediction_index,
)
except PredictionTimeoutError:
raise FuzzError("Prediction timed out")
raise FuzzError(f"{prefix}Prediction timed out")
if error is not None:
raise FuzzError(f"Prediction raised an error: {error}")
raise FuzzError(f"{prefix}Prediction raised an error: {error}")
if not output:
raise FuzzError("No output")
raise FuzzError(f"{prefix}No output")

if error is not None:
raise FuzzError(f"Prediction failed: {error}")
raise FuzzError(f"{prefix}Prediction failed: {error}")
Loading