|
| 1 | +from dataclasses import dataclass, field |
| 2 | + |
| 3 | +from a2a.types import ( |
| 4 | + Artifact, |
| 5 | + Message, |
| 6 | + Part, |
| 7 | + Role, |
| 8 | + Task, |
| 9 | + TaskArtifactUpdateEvent, |
| 10 | + TaskState, |
| 11 | + TaskStatus, |
| 12 | + TaskStatusUpdateEvent, |
| 13 | + TextPart, |
| 14 | +) |
| 15 | + |
| 16 | + |
| 17 | +@dataclass |
| 18 | +class TaskBuilder: |
| 19 | + id: str = 'task-default' |
| 20 | + context_id: str = 'context-default' |
| 21 | + state: TaskState = TaskState.submitted |
| 22 | + kind: str = 'task' |
| 23 | + artifacts: list = field(default_factory=list) |
| 24 | + history: list = field(default_factory=list) |
| 25 | + metadata: dict = field(default_factory=dict) |
| 26 | + |
| 27 | + def with_id(self, id: str) -> 'TaskBuilder': |
| 28 | + self.id = id |
| 29 | + return self |
| 30 | + |
| 31 | + def with_context_id(self, context_id: str) -> 'TaskBuilder': |
| 32 | + self.context_id = context_id |
| 33 | + return self |
| 34 | + |
| 35 | + def with_state(self, state: TaskState) -> 'TaskBuilder': |
| 36 | + self.state = state |
| 37 | + return self |
| 38 | + |
| 39 | + def with_metadata(self, **kwargs) -> 'TaskBuilder': |
| 40 | + self.metadata.update(kwargs) |
| 41 | + return self |
| 42 | + |
| 43 | + def with_history(self, *messages: Message) -> 'TaskBuilder': |
| 44 | + self.history.extend(messages) |
| 45 | + return self |
| 46 | + |
| 47 | + def with_artifacts(self, *artifacts: Artifact) -> 'TaskBuilder': |
| 48 | + self.artifacts.extend(artifacts) |
| 49 | + return self |
| 50 | + |
| 51 | + def build(self) -> Task: |
| 52 | + return Task( |
| 53 | + id=self.id, |
| 54 | + context_id=self.context_id, |
| 55 | + status=TaskStatus(state=self.state), |
| 56 | + kind=self.kind, |
| 57 | + artifacts=self.artifacts if self.artifacts else None, |
| 58 | + history=self.history if self.history else None, |
| 59 | + metadata=self.metadata if self.metadata else None, |
| 60 | + ) |
| 61 | + |
| 62 | + |
| 63 | +@dataclass |
| 64 | +class MessageBuilder: |
| 65 | + role: Role = Role.user |
| 66 | + text: str = 'default message' |
| 67 | + message_id: str = 'msg-default' |
| 68 | + task_id: str | None = None |
| 69 | + context_id: str | None = None |
| 70 | + |
| 71 | + def as_agent(self) -> 'MessageBuilder': |
| 72 | + self.role = Role.agent |
| 73 | + return self |
| 74 | + |
| 75 | + def as_user(self) -> 'MessageBuilder': |
| 76 | + self.role = Role.user |
| 77 | + return self |
| 78 | + |
| 79 | + def with_text(self, text: str) -> 'MessageBuilder': |
| 80 | + self.text = text |
| 81 | + return self |
| 82 | + |
| 83 | + def with_id(self, message_id: str) -> 'MessageBuilder': |
| 84 | + self.message_id = message_id |
| 85 | + return self |
| 86 | + |
| 87 | + def with_task_id(self, task_id: str) -> 'MessageBuilder': |
| 88 | + self.task_id = task_id |
| 89 | + return self |
| 90 | + |
| 91 | + def with_context_id(self, context_id: str) -> 'MessageBuilder': |
| 92 | + self.context_id = context_id |
| 93 | + return self |
| 94 | + |
| 95 | + def build(self) -> Message: |
| 96 | + return Message( |
| 97 | + role=self.role, |
| 98 | + parts=[Part(TextPart(text=self.text))], |
| 99 | + message_id=self.message_id, |
| 100 | + task_id=self.task_id, |
| 101 | + context_id=self.context_id, |
| 102 | + ) |
| 103 | + |
| 104 | + |
| 105 | +@dataclass |
| 106 | +class ArtifactBuilder: |
| 107 | + artifact_id: str = 'artifact-default' |
| 108 | + name: str = 'default artifact' |
| 109 | + text: str = 'default content' |
| 110 | + description: str | None = None |
| 111 | + |
| 112 | + def with_id(self, artifact_id: str) -> 'ArtifactBuilder': |
| 113 | + self.artifact_id = artifact_id |
| 114 | + return self |
| 115 | + |
| 116 | + def with_name(self, name: str) -> 'ArtifactBuilder': |
| 117 | + self.name = name |
| 118 | + return self |
| 119 | + |
| 120 | + def with_text(self, text: str) -> 'ArtifactBuilder': |
| 121 | + self.text = text |
| 122 | + return self |
| 123 | + |
| 124 | + def with_description(self, description: str) -> 'ArtifactBuilder': |
| 125 | + self.description = description |
| 126 | + return self |
| 127 | + |
| 128 | + def build(self) -> Artifact: |
| 129 | + return Artifact( |
| 130 | + artifact_id=self.artifact_id, |
| 131 | + name=self.name, |
| 132 | + parts=[Part(TextPart(text=self.text))], |
| 133 | + description=self.description, |
| 134 | + ) |
| 135 | + |
| 136 | + |
| 137 | +@dataclass |
| 138 | +class StatusUpdateEventBuilder: |
| 139 | + task_id: str = 'task-default' |
| 140 | + context_id: str = 'context-default' |
| 141 | + state: TaskState = TaskState.working |
| 142 | + message: Message | None = None |
| 143 | + final: bool = False |
| 144 | + metadata: dict = field(default_factory=dict) |
| 145 | + |
| 146 | + def for_task(self, task_id: str) -> 'StatusUpdateEventBuilder': |
| 147 | + self.task_id = task_id |
| 148 | + return self |
| 149 | + |
| 150 | + def with_state(self, state: TaskState) -> 'StatusUpdateEventBuilder': |
| 151 | + self.state = state |
| 152 | + return self |
| 153 | + |
| 154 | + def with_message(self, message: Message) -> 'StatusUpdateEventBuilder': |
| 155 | + self.message = message |
| 156 | + return self |
| 157 | + |
| 158 | + def as_final(self) -> 'StatusUpdateEventBuilder': |
| 159 | + self.final = True |
| 160 | + return self |
| 161 | + |
| 162 | + def with_metadata(self, **kwargs) -> 'StatusUpdateEventBuilder': |
| 163 | + self.metadata.update(kwargs) |
| 164 | + return self |
| 165 | + |
| 166 | + def build(self) -> TaskStatusUpdateEvent: |
| 167 | + return TaskStatusUpdateEvent( |
| 168 | + task_id=self.task_id, |
| 169 | + context_id=self.context_id, |
| 170 | + status=TaskStatus(state=self.state, message=self.message), |
| 171 | + final=self.final, |
| 172 | + metadata=self.metadata if self.metadata else None, |
| 173 | + ) |
| 174 | + |
| 175 | + |
| 176 | +@dataclass |
| 177 | +class ArtifactUpdateEventBuilder: |
| 178 | + task_id: str = 'task-default' |
| 179 | + context_id: str = 'context-default' |
| 180 | + artifact: Artifact | None = None |
| 181 | + append: bool = False |
| 182 | + last_chunk: bool = False |
| 183 | + |
| 184 | + def for_task(self, task_id: str) -> 'ArtifactUpdateEventBuilder': |
| 185 | + self.task_id = task_id |
| 186 | + return self |
| 187 | + |
| 188 | + def with_artifact(self, artifact: Artifact) -> 'ArtifactUpdateEventBuilder': |
| 189 | + self.artifact = artifact |
| 190 | + return self |
| 191 | + |
| 192 | + def as_append(self) -> 'ArtifactUpdateEventBuilder': |
| 193 | + self.append = True |
| 194 | + return self |
| 195 | + |
| 196 | + def as_last_chunk(self) -> 'ArtifactUpdateEventBuilder': |
| 197 | + self.last_chunk = True |
| 198 | + return self |
| 199 | + |
| 200 | + def build(self) -> TaskArtifactUpdateEvent: |
| 201 | + if not self.artifact: |
| 202 | + self.artifact = ArtifactBuilder().build() |
| 203 | + return TaskArtifactUpdateEvent( |
| 204 | + task_id=self.task_id, |
| 205 | + context_id=self.context_id, |
| 206 | + artifact=self.artifact, |
| 207 | + append=self.append, |
| 208 | + last_chunk=self.last_chunk, |
| 209 | + ) |
0 commit comments