Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ public interface Constant {
*/
String SERVER_WEBSOCKET = "server.websocket";

/**
* mqtt gateway 配置
*/
String SERVER_MQTT_GATEWAY = "server.mqtt_gateway";


/**
* ota地址
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ public ResponseEntity<String> activateDevice(
@GetMapping
@Hidden
public ResponseEntity<String> getOTA() {
String mqttUdpConfig = sysParamsService.getValue(Constant.SERVER_MQTT_GATEWAY, false);
if(StringUtils.isBlank(mqttUdpConfig)) {
return ResponseEntity.ok("OTA接口不正常,缺少mqtt_gateway地址,请登录智控台,在参数管理找到【server.mqtt_gateway】配置");
}
String wsUrl = sysParamsService.getValue(Constant.SERVER_WEBSOCKET, true);
if (StringUtils.isBlank(wsUrl) || wsUrl.equals("null")) {
return ResponseEntity.ok("OTA接口不正常,缺少websocket地址,请登录智控台,在参数管理找到【server.websocket】配置");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ public class DeviceReportRespDTO {
@Schema(description = "WebSocket配置")
private Websocket websocket;

@Schema(description = "MQTT Gateway配置")
private MQTT mqtt;

@Getter
@Setter
public static class Firmware {
Expand Down Expand Up @@ -70,4 +73,21 @@ public static class Websocket {
@Schema(description = "WebSocket服务器地址")
private String url;
}
}

@Getter
@Setter
public static class MQTT {
@Schema(description = "MQTT 配置网址")
private String endpoint;
@Schema(description = "MQTT 客户端唯一标识符")
private String client_id;
@Schema(description = "MQTT 认证用户名")
private String username;
@Schema(description = "MQTT 认证密码")
private String password;
@Schema(description = "ESP32 发布消息的主题")
private String publish_topic;
@Schema(description = "ESP32 订阅的主题")
private String subscribe_topic;
}
}
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
package xiaozhi.modules.device.service.impl;

import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Base64;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import java.util.UUID;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

import org.apache.commons.lang3.StringUtils;
import org.springframework.aop.framework.AopContext;
import org.springframework.scheduling.annotation.Async;
Expand All @@ -33,6 +38,7 @@
import xiaozhi.common.utils.ConvertUtils;
import xiaozhi.common.utils.DateUtils;
import xiaozhi.modules.device.dao.DeviceDao;
import xiaozhi.modules.device.dto.DeviceManualAddDTO;
import xiaozhi.modules.device.dto.DevicePageUserDTO;
import xiaozhi.modules.device.dto.DeviceReportReqDTO;
import xiaozhi.modules.device.dto.DeviceReportRespDTO;
Expand All @@ -44,7 +50,6 @@
import xiaozhi.modules.security.user.SecurityUser;
import xiaozhi.modules.sys.service.SysParamsService;
import xiaozhi.modules.sys.service.SysUserUtilService;
import xiaozhi.modules.device.dto.DeviceManualAddDTO;

@Slf4j
@Service
Expand Down Expand Up @@ -176,6 +181,21 @@ public DeviceReportRespDTO checkDeviceActive(String macAddress, String clientId,

response.setWebsocket(websocket);

// 添加MQTT UDP配置
// 从系统参数获取MQTT Gateway地址,仅在配置有效时使用
String mqttUdpConfig = sysParamsService.getValue(Constant.SERVER_MQTT_GATEWAY, false);
if (mqttUdpConfig != null && !mqttUdpConfig.equals("null") && !mqttUdpConfig.isEmpty() && deviceById != null) {
try {
DeviceReportRespDTO.MQTT mqtt = buildMqttConfig(macAddress, clientId, deviceById);
if (mqtt != null) {
mqtt.setEndpoint(mqttUdpConfig);
response.setMqtt(mqtt);
}
} catch (Exception e) {
log.error("生成MQTT配置失败: {}", e.getMessage());
}
}

if (deviceById != null) {
// 如果设备存在,则异步更新上次连接时间和版本信息
String appVersion = deviceReport.getApplication() != null ? deviceReport.getApplication().getVersion()
Expand Down Expand Up @@ -437,4 +457,74 @@ public void manualAddDevice(Long userId, DeviceManualAddDTO dto) {
entity.setAutoUpdate(1);
baseDao.insert(entity);
}

/**
* 生成MQTT密码签名
*
* @param content 签名内容 (clientId + '|' + username)
* @param secretKey 密钥
* @return Base64编码的HMAC-SHA256签名
*/
private String generatePasswordSignature(String content, String secretKey) throws Exception {
Mac hmac = Mac.getInstance("HmacSHA256");
SecretKeySpec keySpec = new SecretKeySpec(secretKey.getBytes(StandardCharsets.UTF_8), "HmacSHA256");
hmac.init(keySpec);
byte[] signature = hmac.doFinal(content.getBytes(StandardCharsets.UTF_8));
return Base64.getEncoder().encodeToString(signature);
}

/**
* 构建MQTT配置信息
*
* @param macAddress MAC地址
* @param clientId 客户端ID (UUID)
* @param device 设备信息
* @return MQTT配置对象
*/
private DeviceReportRespDTO.MQTT buildMqttConfig(String macAddress, String clientId, DeviceEntity device)
throws Exception {
// 从环境变量或系统参数获取签名密钥
String signatureKey = sysParamsService.getValue("server.mqtt_signature_key", false);
if (StringUtils.isBlank(signatureKey)) {
log.warn("缺少MQTT_SIGNATURE_KEY,跳过MQTT配置生成");
return null;
}

// 构建客户端ID格式:groupId@@@macAddress_without_colon@@@uuid
String groupId = device.getBoard() != null ? device.getBoard() : "GID_default";
String deviceIdNoColon = macAddress.replace(":", "_");
String mqttClientId = String.format("%s@@@%s@@@%s", groupId, deviceIdNoColon, clientId);

// 构建用户数据(包含IP等信息)
Map<String, String> userData = new HashMap<>();
// 尝试获取客户端IP
try {
ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder
.getRequestAttributes();
if (attributes != null) {
HttpServletRequest request = attributes.getRequest();
String clientIp = request.getRemoteAddr();
userData.put("ip", clientIp);
}
} catch (Exception e) {
userData.put("ip", "unknown");
}

// 将用户数据编码为Base64 JSON
String userDataJson = new com.fasterxml.jackson.databind.ObjectMapper().writeValueAsString(userData);
String username = Base64.getEncoder().encodeToString(userDataJson.getBytes(StandardCharsets.UTF_8));

// 生成密码签名
String password = generatePasswordSignature(mqttClientId + "|" + username, signatureKey);

// 构建MQTT配置
DeviceReportRespDTO.MQTT mqtt = new DeviceReportRespDTO.MQTT();
mqtt.setClient_id(mqttClientId);
mqtt.setUsername(username);
mqtt.setPassword(password);
mqtt.setPublish_topic("device-server");
mqtt.setSubscribe_topic("devices/p2p/" + deviceIdNoColon);

return mqtt;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
delete from `sys_params` where param_code = 'server.mqtt_gateway';
INSERT INTO `sys_params` (id, param_code, param_value, value_type, param_type, remark) VALUES (116, 'server.mqtt_gateway', 'null', 'string', 1, 'mqtt gateway 配置');

delete from `sys_params` where param_code = 'server.mqtt_signature_key';
INSERT INTO `sys_params` (id, param_code, param_value, value_type, param_type, remark) VALUES (117, 'server.mqtt_signature_key', 'null', 'string', 1, 'mqtt 密钥 配置');

delete from `sys_params` where param_code = 'server.udp_gateway';
INSERT INTO `sys_params` (id, param_code, param_value, value_type, param_type, remark) VALUES (118, 'server.udp_gateway', 'null', 'string', 1, 'udp gateway 配置');
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,17 @@ databaseChangeLog:
- sqlFile:
encoding: utf8
path: classpath:db/changelog/202509091042.sql

- changeSet:
id: 202509091633
author: fyb
changes:
- sqlFile:
encoding: utf8
path: classpath:db/changelog/202509091633.sql
- changeSet:
id: 202509080922
author: fyb
changes:
- sqlFile:
encoding: utf8
path: classpath:db/changelog/202509080922.sql
95 changes: 91 additions & 4 deletions main/xiaozhi-server/core/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def __init__(
# {"mcp":true} 表示启用MCP功能
self.features = None

# 标记连接是否来自MQTT
self.conn_from_mqtt_gateway = False

# 初始化提示词管理器
self.prompt_manager = PromptManager(config, self.logger)

Expand Down Expand Up @@ -198,6 +201,12 @@ async def handle_connection(self, ws):
self.websocket = ws
self.device_id = self.headers.get("device-id", None)

# 检查是否来自MQTT连接
request_path = ws.request.path
self.conn_from_mqtt_gateway = request_path.endswith("?from=mqtt_gateway")
if self.conn_from_mqtt_gateway:
self.logger.bind(tag=TAG).info("连接来自:MQTT网关")

# 初始化活动时间戳
self.last_activity_time = time.time() * 1000

Expand Down Expand Up @@ -277,12 +286,82 @@ async def _route_message(self, message):
if isinstance(message, str):
await handleTextMessage(self, message)
elif isinstance(message, bytes):
if self.vad is None:
return
if self.asr is None:
if self.vad is None or self.asr is None:
return

# 处理来自MQTT网关的音频包
if self.conn_from_mqtt_gateway and len(message) >= 16:
handled = await self._process_mqtt_audio_message(message)
if handled:
return

# 不需要头部处理或没有头部时,直接处理原始消息
self.asr_audio_queue.put(message)

async def _process_mqtt_audio_message(self, message):
"""
处理来自MQTT网关的音频消息,解析16字节头部并提取音频数据

Args:
message: 包含头部的音频消息

Returns:
bool: 是否成功处理了消息
"""
try:
# 提取头部信息
timestamp = int.from_bytes(message[8:12], "big")
audio_length = int.from_bytes(message[12:16], "big")

# 提取音频数据
if audio_length > 0 and len(message) >= 16 + audio_length:
# 有指定长度,提取精确的音频数据
audio_data = message[16 : 16 + audio_length]
# 基于时间戳进行排序处理
self._process_websocket_audio(audio_data, timestamp)
return True
elif len(message) > 16:
# 没有指定长度或长度无效,去掉头部后处理剩余数据
audio_data = message[16:]
self.asr_audio_queue.put(audio_data)
return True
except Exception as e:
self.logger.bind(tag=TAG).error(f"解析WebSocket音频包失败: {e}")

# 处理失败,返回False表示需要继续处理
return False

def _process_websocket_audio(self, audio_data, timestamp):
"""处理WebSocket格式的音频包"""
# 初始化时间戳序列管理
if not hasattr(self, "audio_timestamp_buffer"):
self.audio_timestamp_buffer = {}
self.last_processed_timestamp = 0
self.max_timestamp_buffer_size = 20

# 如果时间戳是递增的,直接处理
if timestamp >= self.last_processed_timestamp:
self.asr_audio_queue.put(audio_data)
self.last_processed_timestamp = timestamp

# 处理缓冲区中的后续包
processed_any = True
while processed_any:
processed_any = False
for ts in sorted(self.audio_timestamp_buffer.keys()):
if ts > self.last_processed_timestamp:
buffered_audio = self.audio_timestamp_buffer.pop(ts)
self.asr_audio_queue.put(buffered_audio)
self.last_processed_timestamp = ts
processed_any = True
break
else:
# 乱序包,暂存
if len(self.audio_timestamp_buffer) < self.max_timestamp_buffer_size:
self.audio_timestamp_buffer[timestamp] = audio_data
else:
self.asr_audio_queue.put(audio_data)

async def handle_restart(self, message):
"""处理服务器重启请求"""
try:
Expand Down Expand Up @@ -857,7 +936,11 @@ def _handle_function_result(self, result, function_call_data, depth):
{
"id": function_id,
"function": {
"arguments": "{}" if function_arguments == "" else function_arguments,
"arguments": (
"{}"
if function_arguments == ""
else function_arguments
),
"name": function_name,
},
"type": "function",
Expand Down Expand Up @@ -925,6 +1008,10 @@ def clearSpeakStatus(self):
async def close(self, ws=None):
"""资源清理方法"""
try:
# 清理音频缓冲区
if hasattr(self, "audio_buffer"):
self.audio_buffer.clear()

# 取消超时任务
if self.timeout_task and not self.timeout_task.done():
self.timeout_task.cancel()
Expand Down
10 changes: 6 additions & 4 deletions main/xiaozhi-server/core/handle/receiveAudioHandle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,25 @@ async def handleAudioMessage(conn, audio):
# 接收音频
await conn.asr.receive_audio(conn, audio, have_voice)


async def resume_vad_detection(conn):
# 等待2秒后恢复VAD检测
await asyncio.sleep(1)
conn.just_woken_up = False


async def startToChat(conn, text):
# 检查输入是否是JSON格式(包含说话人信息)
speaker_name = None
actual_text = text

try:
# 尝试解析JSON格式的输入
if text.strip().startswith('{') and text.strip().endswith('}'):
if text.strip().startswith("{") and text.strip().endswith("}"):
data = json.loads(text)
if 'speaker' in data and 'content' in data:
speaker_name = data['speaker']
actual_text = data['content']
if "speaker" in data and "content" in data:
speaker_name = data["speaker"]
actual_text = data["content"]
conn.logger.bind(tag=TAG).info(f"解析到说话人信息: {speaker_name}")

# 直接使用JSON格式的文本,不解析
Expand Down
Loading