Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions bsllmner2/cli_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import asyncio
import sys
from pathlib import Path
from typing import List, Tuple

from bsllmner2.client.ollama import ner
from bsllmner2.config import (LOGGER, PROMPT_EXTRACT_FILE_PATH, Config,
Expand All @@ -14,7 +13,7 @@
load_prompt_file, to_result)


def parse_args(args: List[str]) -> Tuple[Config, CliExtractArgs]:
def parse_args(args: list[str]) -> tuple[Config, CliExtractArgs]:
"""
Parse command-line arguments for the bsllmner2 CLI extract mode.

Expand All @@ -33,8 +32,9 @@ def parse_args(args: List[str]) -> Tuple[Config, CliExtractArgs]:
parser.add_argument(
"--mapping",
type=Path,
required=True,
required=False,
help="Path to the mapping file in TSV format.",
default = Path.cwd()
)
parser.add_argument(
"--prompt",
Expand Down
17 changes: 8 additions & 9 deletions bsllmner2/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional
from typing import Any, Literal, Optional

from ollama import ChatResponse
from pydantic import BaseModel, Field
Expand All @@ -10,7 +10,6 @@

API_VERSION = "1.0.0"


class CliExtractArgs(BaseModel):
"""
Command-line arguments for the bsllmner2 CLI extract mode.
Expand Down Expand Up @@ -67,7 +66,7 @@ class Prompt(BaseModel):
)


BsEntries = List[Dict[str, Any]]
BsEntries = list[dict[str, Any]]


class MappingValue(BaseModel):
Expand All @@ -77,13 +76,13 @@ class MappingValue(BaseModel):
mapping_answer_label: Optional[str]


Mapping = Dict[str, MappingValue] # key: bs_entry accession
Mapping = dict[str, MappingValue] # key: bs_entry accession


class WfInput(BaseModel):
bs_entries: BsEntries
mapping: Mapping
prompt: List[Prompt]
prompt: list[Prompt]
model: str
thinking: Optional[bool] = None
format: Optional[JsonSchemaValue] = None
Expand All @@ -95,7 +94,7 @@ class LlmOutput(BaseModel):
accession: str
output: Optional[Any] = None
output_full: Optional[str] = None
characteristics: Optional[Dict[str, Any]] = None
characteristics: Optional[dict[str, Any]] = None
taxId: Optional[Any] = None
chat_response: ChatResponse

Expand Down Expand Up @@ -135,8 +134,8 @@ class ErrorLog(BaseModel):

class Result(BaseModel):
input: WfInput
output: List[LlmOutput] = []
evaluation: List[Evaluation] = []
metrics: Optional[List[Metrics]] = None
output: list[LlmOutput] = []
evaluation: list[Evaluation] = []
metrics: Optional[list[Metrics]] = None
run_metadata: RunMetadata
error_log: Optional[ErrorLog] = None
11 changes: 7 additions & 4 deletions bsllmner2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,13 @@ def load_mapping(path: Path) -> Mapping:

mapping: Mapping = {}

with path.open("r", encoding="utf-8") as f:
lines = [line.rstrip("\n") for line in f if line.strip()]
if not lines:
return {}
if path.is_file():
with path.open("rt", encoding="utf-8") as f:
lines = [line.rstrip("\n") for line in f if line.strip()]
if not lines:
return mapping
else:
return mapping

header_fields = lines[0].split("\t")
if header_fields != HEADERS:
Expand Down