feat: 添加实时会议WebSocket支持和相关服务

- 创建 `RealtimeMeetingSocketSessionService` 及其实现类,用于创建和获取实时会议会话
- 添加 `WebSocketSecurityConfig` 以配置WebSocket安全
- 创建 `RealtimeMeetingProxyWebSocketHandler` 处理WebSocket消息代理
- 配置 `RealtimeMeetingWebSocketConfig` 注
dev_na
chenhao 2026-03-31 09:54:08 +08:00
parent 9d1a8710af
commit f9c0d31b87
10 changed files with 734 additions and 0 deletions

View File

@ -0,0 +1,22 @@
package com.imeeting.config;
import com.imeeting.websocket.RealtimeMeetingProxyWebSocketHandler;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
@Configuration
@EnableWebSocket
@RequiredArgsConstructor
public class RealtimeMeetingWebSocketConfig implements WebSocketConfigurer {
private final RealtimeMeetingProxyWebSocketHandler realtimeMeetingProxyWebSocketHandler;
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(realtimeMeetingProxyWebSocketHandler, "/ws/meeting/realtime")
.setAllowedOriginPatterns("*");
}
}

View File

@ -0,0 +1,21 @@
package com.imeeting.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.annotation.Order;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.web.SecurityFilterChain;
@Configuration
public class WebSocketSecurityConfig {
@Bean
@Order(0)
public SecurityFilterChain webSocketSecurityFilterChain(HttpSecurity http) throws Exception {
http.securityMatcher("/ws/**")
.csrf(AbstractHttpConfigurer::disable)
.authorizeHttpRequests(authorize -> authorize.anyRequest().permitAll());
return http.build();
}
}

View File

@ -0,0 +1,16 @@
package com.imeeting.dto.biz;
import lombok.Data;
import java.math.BigDecimal;
import java.util.List;
@Data
public class AiLocalProfileVO {
private List<String> asrModels;
private List<String> speakerModels;
private String activeAsrModel;
private String activeSpeakerModel;
private BigDecimal svThreshold;
private String wsEndpoint;
}

View File

@ -0,0 +1,19 @@
package com.imeeting.dto.biz;
import lombok.Data;
import java.util.List;
import java.util.Map;
@Data
public class OpenRealtimeSocketSessionCommand {
private Long asrModelId;
private String mode;
private String language;
private Integer useSpkId;
private Boolean enablePunctuation;
private Boolean enableItn;
private Boolean enableTextRefine;
private Boolean saveAudio;
private List<Map<String, Object>> hotwords;
}

View File

@ -0,0 +1,12 @@
package com.imeeting.dto.biz;
import lombok.Data;
@Data
public class RealtimeSocketSessionData {
private Long meetingId;
private Long userId;
private Long tenantId;
private Long asrModelId;
private String targetWsUrl;
}

View File

@ -0,0 +1,13 @@
package com.imeeting.dto.biz;
import lombok.Data;
import java.util.Map;
@Data
public class RealtimeSocketSessionVO {
private String sessionToken;
private String path;
private Long expiresInSeconds;
private Map<String, Object> startMessage;
}

View File

@ -0,0 +1,10 @@
package com.imeeting.dto.biz;
import lombok.Data;
@Data
public class UpdateMeetingTranscriptCommand {
private Long meetingId;
private Long transcriptId;
private String content;
}

View File

@ -0,0 +1,17 @@
package com.imeeting.service.biz;
import com.imeeting.dto.biz.RealtimeSocketSessionData;
import com.imeeting.dto.biz.RealtimeSocketSessionVO;
import com.unisbase.security.LoginUser;
import java.util.List;
import java.util.Map;
public interface RealtimeMeetingSocketSessionService {
RealtimeSocketSessionVO createSession(Long meetingId, Long asrModelId, String mode, String language,
Integer useSpkId, Boolean enablePunctuation, Boolean enableItn,
Boolean enableTextRefine, Boolean saveAudio,
List<Map<String, Object>> hotwords, LoginUser loginUser);
RealtimeSocketSessionData getSessionData(String sessionToken);
}

View File

@ -0,0 +1,182 @@
package com.imeeting.service.biz.impl;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.imeeting.common.RedisKeys;
import com.imeeting.dto.biz.AiModelVO;
import com.imeeting.dto.biz.RealtimeSocketSessionData;
import com.imeeting.dto.biz.RealtimeSocketSessionVO;
import com.imeeting.entity.biz.Meeting;
import com.imeeting.service.biz.AiModelService;
import com.imeeting.service.biz.MeetingAccessService;
import com.imeeting.service.biz.RealtimeMeetingSocketSessionService;
import com.unisbase.security.LoginUser;
import lombok.RequiredArgsConstructor;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Service;
import java.time.Duration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@Service
@RequiredArgsConstructor
public class RealtimeMeetingSocketSessionServiceImpl implements RealtimeMeetingSocketSessionService {
private static final Duration SESSION_TTL = Duration.ofMinutes(10);
private static final String WS_PATH = "/ws/meeting/realtime";
private final ObjectMapper objectMapper;
private final StringRedisTemplate redisTemplate;
private final MeetingAccessService meetingAccessService;
private final AiModelService aiModelService;
@Override
public RealtimeSocketSessionVO createSession(Long meetingId, Long asrModelId, String mode, String language,
Integer useSpkId, Boolean enablePunctuation, Boolean enableItn,
Boolean enableTextRefine, Boolean saveAudio,
List<Map<String, Object>> hotwords, LoginUser loginUser) {
if (meetingId == null) {
throw new RuntimeException("Meeting ID is required");
}
if (asrModelId == null) {
throw new RuntimeException("ASR model ID is required");
}
Meeting meeting = meetingAccessService.requireMeeting(meetingId);
meetingAccessService.assertCanManageRealtimeMeeting(meeting, loginUser);
AiModelVO asrModel = aiModelService.getModelById(asrModelId, "ASR");
if (asrModel == null) {
throw new RuntimeException("ASR model not found");
}
String targetWsUrl = resolveWsUrl(asrModel);
if (targetWsUrl == null || targetWsUrl.isBlank()) {
throw new RuntimeException("ASR model WebSocket is not configured");
}
RealtimeSocketSessionData sessionData = new RealtimeSocketSessionData();
sessionData.setMeetingId(meetingId);
sessionData.setUserId(loginUser.getUserId());
sessionData.setTenantId(loginUser.getTenantId());
sessionData.setAsrModelId(asrModelId);
sessionData.setTargetWsUrl(targetWsUrl);
String sessionToken = UUID.randomUUID().toString().replace("-", "");
try {
redisTemplate.opsForValue().set(
RedisKeys.realtimeMeetingSocketSessionKey(sessionToken),
objectMapper.writeValueAsString(sessionData),
SESSION_TTL
);
} catch (Exception ex) {
throw new RuntimeException("Failed to create realtime socket session", ex);
}
RealtimeSocketSessionVO vo = new RealtimeSocketSessionVO();
vo.setSessionToken(sessionToken);
vo.setPath(WS_PATH);
vo.setExpiresInSeconds(SESSION_TTL.toSeconds());
vo.setStartMessage(buildStartMessage(
asrModel,
meetingId,
mode,
language,
useSpkId,
enablePunctuation,
enableItn,
enableTextRefine,
saveAudio,
hotwords
));
return vo;
}
@Override
public RealtimeSocketSessionData getSessionData(String sessionToken) {
if (sessionToken == null || sessionToken.isBlank()) {
return null;
}
String raw = redisTemplate.opsForValue().get(RedisKeys.realtimeMeetingSocketSessionKey(sessionToken));
if (raw == null || raw.isBlank()) {
return null;
}
try {
return objectMapper.readValue(raw, RealtimeSocketSessionData.class);
} catch (Exception ex) {
throw new RuntimeException("Failed to read realtime socket session", ex);
}
}
private String resolveWsUrl(AiModelVO model) {
if (model.getWsUrl() != null && !model.getWsUrl().isBlank()) {
return model.getWsUrl();
}
if (model.getBaseUrl() == null || model.getBaseUrl().isBlank()) {
return "";
}
return model.getBaseUrl()
.replaceFirst("^http://", "ws://")
.replaceFirst("^https://", "wss://");
}
private Map<String, Object> buildStartMessage(AiModelVO model, Long meetingId, String mode, String language,
Integer useSpkId, Boolean enablePunctuation, Boolean enableItn,
Boolean enableTextRefine, Boolean saveAudio,
List<Map<String, Object>> hotwords) {
Map<String, Object> root = new HashMap<>();
root.put("type", "start");
root.put("request_id", "web_" + System.currentTimeMillis() + "_" + meetingId);
root.put("authorization", buildAuthorization(model.getApiKey()));
Map<String, Object> config = new HashMap<>();
Map<String, Object> audio = new HashMap<>();
audio.put("format", "pcm");
audio.put("sample_rate", 16000);
audio.put("channels", 1);
config.put("audio", audio);
Map<String, Object> recognition = new HashMap<>();
recognition.put("language", normalizeLanguage(language));
recognition.put("enable_punctuation", boolOrDefault(enablePunctuation, true));
recognition.put("enable_itn", boolOrDefault(enableItn, true));
recognition.put("enable_speaker", Integer.valueOf(1).equals(useSpkId));
recognition.put("enable_two_pass", !"online".equalsIgnoreCase(mode));
recognition.put("enable_text_refine", boolOrDefault(enableTextRefine, false));
recognition.put("speaker_threshold", readSpeakerThreshold(model.getMediaConfig()));
recognition.put("hotwords", hotwords == null ? List.of() : hotwords);
config.put("recognition", recognition);
config.put("model", model.getModelCode());
config.put("save_audio", boolOrDefault(saveAudio, false));
root.put("config", config);
return root;
}
private String buildAuthorization(String apiKey) {
if (apiKey == null || apiKey.isBlank()) {
return "";
}
return apiKey.startsWith("Bearer ") ? apiKey : "Bearer " + apiKey;
}
private Object readSpeakerThreshold(Map<String, Object> mediaConfig) {
if (mediaConfig == null) {
return null;
}
return mediaConfig.get("svThreshold");
}
private String normalizeLanguage(String language) {
if (language == null || language.isBlank()) {
return "auto";
}
return language.trim();
}
private boolean boolOrDefault(Boolean value, boolean defaultValue) {
return value != null ? value : defaultValue;
}
}

View File

@ -0,0 +1,422 @@
package com.imeeting.websocket;
import com.imeeting.dto.biz.RealtimeSocketSessionData;
import com.imeeting.service.biz.RealtimeMeetingSocketSessionService;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PingMessage;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import java.io.ByteArrayOutputStream;
import java.net.URI;
import java.net.URLDecoder;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Supplier;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
@Component
@RequiredArgsConstructor
public class RealtimeMeetingProxyWebSocketHandler extends AbstractWebSocketHandler {
private static final String ATTR_FRONTEND_SESSION = "frontendSession";
private static final String ATTR_UPSTREAM_SOCKET = "upstreamSocket";
private static final String ATTR_MEETING_ID = "meetingId";
private static final String ATTR_TARGET_WS_URL = "targetWsUrl";
private static final String ATTR_FRONTEND_TEXT_COUNT = "frontendTextCount";
private static final String ATTR_FRONTEND_BINARY_COUNT = "frontendBinaryCount";
private static final String ATTR_UPSTREAM_SEND_CHAIN = "upstreamSendChain";
private static final String ATTR_START_MESSAGE_SENT = "startMessageSent";
private static final String ATTR_PENDING_AUDIO_FRAMES = "pendingAudioFrames";
private static final CompletableFuture<Void> COMPLETED = CompletableFuture.completedFuture(null);
private final RealtimeMeetingSocketSessionService realtimeMeetingSocketSessionService;
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
String sessionToken = extractQueryParam(session.getUri(), "sessionToken");
RealtimeSocketSessionData sessionData = realtimeMeetingSocketSessionService.getSessionData(sessionToken);
if (sessionData == null) {
log.warn("Realtime websocket rejected: invalid session token, sessionId={}", session.getId());
session.close(CloseStatus.POLICY_VIOLATION.withReason("Invalid realtime socket session"));
return;
}
ConcurrentWebSocketSessionDecorator frontendSession =
new ConcurrentWebSocketSessionDecorator(session, (int) Duration.ofSeconds(15).toMillis(), 1024 * 1024);
session.getAttributes().put(ATTR_FRONTEND_SESSION, frontendSession);
session.getAttributes().put(ATTR_MEETING_ID, sessionData.getMeetingId());
session.getAttributes().put(ATTR_TARGET_WS_URL, sessionData.getTargetWsUrl());
session.getAttributes().put(ATTR_FRONTEND_TEXT_COUNT, new AtomicInteger());
session.getAttributes().put(ATTR_FRONTEND_BINARY_COUNT, new AtomicInteger());
session.getAttributes().put(ATTR_UPSTREAM_SEND_CHAIN, COMPLETED);
session.getAttributes().put(ATTR_START_MESSAGE_SENT, Boolean.FALSE);
session.getAttributes().put(ATTR_PENDING_AUDIO_FRAMES, new ArrayList<byte[]>());
log.info("Realtime websocket accepted: meetingId={}, sessionId={}, upstream={}",
sessionData.getMeetingId(), session.getId(), sessionData.getTargetWsUrl());
java.net.http.WebSocket upstreamSocket;
try {
upstreamSocket = java.net.http.HttpClient.newHttpClient()
.newWebSocketBuilder()
.buildAsync(URI.create(sessionData.getTargetWsUrl()),
new UpstreamListener(frontendSession, session, sessionData.getMeetingId(), sessionData.getTargetWsUrl()))
.get();
} catch (InterruptedException ex) {
Thread.currentThread().interrupt();
log.error("Realtime websocket upstream connect interrupted: meetingId={}, sessionId={}",
sessionData.getMeetingId(), session.getId(), ex);
frontendSession.close(CloseStatus.SERVER_ERROR.withReason("Interrupted while connecting upstream"));
return;
} catch (ExecutionException | CompletionException ex) {
log.warn("Failed to connect upstream websocket, meetingId={}, target={}", sessionData.getMeetingId(), sessionData.getTargetWsUrl(), ex);
frontendSession.close(CloseStatus.SERVER_ERROR.withReason("Failed to connect ASR websocket"));
return;
}
session.getAttributes().put(ATTR_UPSTREAM_SOCKET, upstreamSocket);
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) {
java.net.http.WebSocket upstreamSocket = getUpstreamSocket(session);
if (upstreamSocket == null) {
log.warn("Frontend text ignored because upstream socket is unavailable, meetingId={}, sessionId={}",
session.getAttributes().get(ATTR_MEETING_ID), session.getId());
return;
}
int count = nextCount(session, ATTR_FRONTEND_TEXT_COUNT);
log.info("Frontend text -> upstream: meetingId={}, sessionId={}, count={}, payload={}",
session.getAttributes().get(ATTR_MEETING_ID), session.getId(), count, summarizeText(message.getPayload()));
sendUpstreamOrdered(session, () -> upstreamSocket.sendText(message.getPayload(), true), "text");
if (looksLikeStartMessage(message.getPayload())) {
session.getAttributes().put(ATTR_START_MESSAGE_SENT, Boolean.TRUE);
flushPendingAudioFrames(session, upstreamSocket);
}
}
@Override
protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) {
java.net.http.WebSocket upstreamSocket = getUpstreamSocket(session);
if (upstreamSocket == null) {
log.warn("Frontend binary ignored because upstream socket is unavailable, meetingId={}, sessionId={}",
session.getAttributes().get(ATTR_MEETING_ID), session.getId());
return;
}
int count = nextCount(session, ATTR_FRONTEND_BINARY_COUNT);
int bytes = message.getPayloadLength();
if (shouldLogBinaryFrame(count)) {
log.info("Frontend binary -> upstream: meetingId={}, sessionId={}, count={}, bytes={}",
session.getAttributes().get(ATTR_MEETING_ID), session.getId(), count, bytes);
}
byte[] payload = toByteArray(message.getPayload());
if (!Boolean.TRUE.equals(session.getAttributes().get(ATTR_START_MESSAGE_SENT))) {
queuePendingAudioFrame(session, payload);
if (shouldLogBinaryFrame(count)) {
log.warn("Frontend binary queued before start message: meetingId={}, sessionId={}, count={}, bytes={}",
session.getAttributes().get(ATTR_MEETING_ID), session.getId(), count, bytes);
}
return;
}
sendUpstreamOrdered(session, () -> upstreamSocket.sendBinary(ByteBuffer.wrap(payload), true), "binary");
}
@Override
protected void handlePongMessage(WebSocketSession session, PongMessage message) {
java.net.http.WebSocket upstreamSocket = getUpstreamSocket(session);
if (upstreamSocket == null) {
return;
}
sendUpstreamOrdered(session, () -> upstreamSocket.sendPong(copyBuffer(message.getPayload())), "pong");
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
log.error("Realtime websocket transport error: meetingId={}, sessionId={}, upstream={}",
session.getAttributes().get(ATTR_MEETING_ID), session.getId(), session.getAttributes().get(ATTR_TARGET_WS_URL), exception);
closeUpstreamSocket(session, CloseStatus.SERVER_ERROR);
if (session.isOpen()) {
session.close(CloseStatus.SERVER_ERROR);
}
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
log.info("Realtime websocket closed: meetingId={}, sessionId={}, code={}, reason={}",
session.getAttributes().get(ATTR_MEETING_ID), session.getId(), status.getCode(), status.getReason());
closeUpstreamSocket(session, status);
}
private java.net.http.WebSocket getUpstreamSocket(WebSocketSession session) {
Object socket = session.getAttributes().get(ATTR_UPSTREAM_SOCKET);
if (socket instanceof java.net.http.WebSocket webSocket) {
return webSocket;
}
return null;
}
private void closeUpstreamSocket(WebSocketSession session, CloseStatus status) {
java.net.http.WebSocket upstreamSocket = getUpstreamSocket(session);
if (upstreamSocket != null) {
upstreamSocket.sendClose(status.getCode(), status.getReason() == null ? "" : status.getReason());
session.getAttributes().remove(ATTR_UPSTREAM_SOCKET);
}
}
private String extractQueryParam(URI uri, String key) {
if (uri == null || uri.getQuery() == null || uri.getQuery().isBlank()) {
return null;
}
return Arrays.stream(uri.getQuery().split("&"))
.map(item -> item.split("=", 2))
.filter(parts -> parts.length == 2 && key.equals(parts[0]))
.map(parts -> URLDecoder.decode(parts[1], StandardCharsets.UTF_8))
.findFirst()
.orElse(null);
}
private ByteBuffer copyBuffer(ByteBuffer source) {
ByteBuffer duplicate = source.asReadOnlyBuffer();
byte[] bytes = new byte[duplicate.remaining()];
duplicate.get(bytes);
return ByteBuffer.wrap(bytes);
}
private byte[] toByteArray(ByteBuffer source) {
ByteBuffer duplicate = source.asReadOnlyBuffer();
byte[] bytes = new byte[duplicate.remaining()];
duplicate.get(bytes);
return bytes;
}
private int nextCount(WebSocketSession session, String key) {
Object value = session.getAttributes().get(key);
if (value instanceof AtomicInteger counter) {
return counter.incrementAndGet();
}
return 0;
}
@SuppressWarnings("unchecked")
private void sendUpstreamOrdered(WebSocketSession session, Supplier<CompletableFuture<?>> sendAction, String messageType) {
synchronized (session) {
CompletableFuture<Void> chain = (CompletableFuture<Void>) session.getAttributes()
.getOrDefault(ATTR_UPSTREAM_SEND_CHAIN, COMPLETED);
CompletableFuture<Void> nextChain = chain
.exceptionally(ex -> null)
.thenCompose(ignored -> sendAction.get().thenApply(ignoredResult -> null));
nextChain = nextChain.whenComplete((ignored, ex) -> {
if (ex != null) {
log.error("Ordered upstream send failed: meetingId={}, sessionId={}, type={}",
session.getAttributes().get(ATTR_MEETING_ID), session.getId(), messageType, ex);
}
});
session.getAttributes().put(ATTR_UPSTREAM_SEND_CHAIN, nextChain);
}
}
private static boolean shouldLogBinaryFrame(int count) {
return count <= 3 || count % 25 == 0;
}
private static String summarizeText(String payload) {
if (payload == null) {
return "";
}
String normalized = payload.replaceAll("\\s+", " ").trim();
if (normalized.length() <= 240) {
return normalized;
}
return normalized.substring(0, 240) + "...";
}
private boolean looksLikeStartMessage(String payload) {
if (payload == null || payload.isBlank()) {
return false;
}
String normalized = payload.replaceAll("\\s+", "");
return normalized.contains("\"type\":\"start\"");
}
@SuppressWarnings("unchecked")
private void queuePendingAudioFrame(WebSocketSession session, byte[] payload) {
synchronized (session) {
List<byte[]> pendingFrames = (List<byte[]>) session.getAttributes().get(ATTR_PENDING_AUDIO_FRAMES);
if (pendingFrames == null) {
pendingFrames = new ArrayList<>();
session.getAttributes().put(ATTR_PENDING_AUDIO_FRAMES, pendingFrames);
}
pendingFrames.add(payload);
}
}
@SuppressWarnings("unchecked")
private void flushPendingAudioFrames(WebSocketSession session, java.net.http.WebSocket upstreamSocket) {
List<byte[]> pendingFrames;
synchronized (session) {
pendingFrames = (List<byte[]>) session.getAttributes().get(ATTR_PENDING_AUDIO_FRAMES);
if (pendingFrames == null || pendingFrames.isEmpty()) {
return;
}
session.getAttributes().put(ATTR_PENDING_AUDIO_FRAMES, new ArrayList<byte[]>());
}
log.info("Flushing queued audio frames after start message: meetingId={}, sessionId={}, frameCount={}",
session.getAttributes().get(ATTR_MEETING_ID), session.getId(), pendingFrames.size());
for (byte[] frame : pendingFrames) {
sendUpstreamOrdered(session, () -> upstreamSocket.sendBinary(ByteBuffer.wrap(frame), true), "binary-flush");
}
}
private static final class UpstreamListener implements java.net.http.WebSocket.Listener {
private final ConcurrentWebSocketSessionDecorator frontendSession;
private final WebSocketSession rawSession;
private final Long meetingId;
private final String targetWsUrl;
private final StringBuilder textBuffer = new StringBuilder();
private final ByteArrayOutputStream binaryBuffer = new ByteArrayOutputStream();
private final AtomicInteger upstreamTextCount = new AtomicInteger();
private final AtomicInteger upstreamBinaryCount = new AtomicInteger();
private UpstreamListener(ConcurrentWebSocketSessionDecorator frontendSession, WebSocketSession rawSession,
Long meetingId, String targetWsUrl) {
this.frontendSession = frontendSession;
this.rawSession = rawSession;
this.meetingId = meetingId;
this.targetWsUrl = targetWsUrl;
}
@Override
public void onOpen(java.net.http.WebSocket webSocket) {
log.info("Upstream websocket opened: meetingId={}, sessionId={}, upstream={}",
meetingId, rawSession.getId(), targetWsUrl);
webSocket.request(1);
}
@Override
public java.util.concurrent.CompletionStage<?> onText(java.net.http.WebSocket webSocket, CharSequence data, boolean last) {
textBuffer.append(data);
if (last) {
int count = upstreamTextCount.incrementAndGet();
try {
if (frontendSession.isOpen()) {
frontendSession.sendMessage(new TextMessage(textBuffer.toString()));
}
log.info("Upstream text -> frontend: meetingId={}, sessionId={}, count={}, payload={}",
meetingId, rawSession.getId(), count, summarizeText(textBuffer.toString()));
} catch (Exception ex) {
log.error("Failed to forward upstream text: meetingId={}, sessionId={}", meetingId, rawSession.getId(), ex);
closeFrontend(CloseStatus.SERVER_ERROR);
} finally {
textBuffer.setLength(0);
}
}
webSocket.request(1);
return COMPLETED;
}
@Override
public java.util.concurrent.CompletionStage<?> onBinary(java.net.http.WebSocket webSocket, ByteBuffer data, boolean last) {
byte[] chunk = new byte[data.remaining()];
data.get(chunk);
binaryBuffer.writeBytes(chunk);
if (last) {
int count = upstreamBinaryCount.incrementAndGet();
try {
if (frontendSession.isOpen()) {
frontendSession.sendMessage(new BinaryMessage(binaryBuffer.toByteArray()));
}
if (shouldLogBinaryFrame(count)) {
log.info("Upstream binary -> frontend: meetingId={}, sessionId={}, count={}, bytes={}",
meetingId, rawSession.getId(), count, binaryBuffer.size());
}
} catch (Exception ex) {
log.error("Failed to forward upstream binary: meetingId={}, sessionId={}", meetingId, rawSession.getId(), ex);
closeFrontend(CloseStatus.SERVER_ERROR);
} finally {
binaryBuffer.reset();
}
}
webSocket.request(1);
return COMPLETED;
}
@Override
public java.util.concurrent.CompletionStage<?> onPing(java.net.http.WebSocket webSocket, ByteBuffer message) {
try {
if (frontendSession.isOpen()) {
frontendSession.sendMessage(new PingMessage(copyBuffer(message)));
}
log.info("Upstream ping -> frontend: meetingId={}, sessionId={}, bytes={}",
meetingId, rawSession.getId(), message.remaining());
} catch (Exception ex) {
log.error("Failed to forward upstream ping: meetingId={}, sessionId={}", meetingId, rawSession.getId(), ex);
closeFrontend(CloseStatus.SERVER_ERROR);
}
webSocket.request(1);
return COMPLETED;
}
@Override
public java.util.concurrent.CompletionStage<?> onPong(java.net.http.WebSocket webSocket, ByteBuffer message) {
try {
if (frontendSession.isOpen()) {
frontendSession.sendMessage(new PongMessage(copyBuffer(message)));
}
log.info("Upstream pong -> frontend: meetingId={}, sessionId={}, bytes={}",
meetingId, rawSession.getId(), message.remaining());
} catch (Exception ex) {
log.error("Failed to forward upstream pong: meetingId={}, sessionId={}", meetingId, rawSession.getId(), ex);
closeFrontend(CloseStatus.SERVER_ERROR);
}
webSocket.request(1);
return COMPLETED;
}
@Override
public java.util.concurrent.CompletionStage<?> onClose(java.net.http.WebSocket webSocket, int statusCode, String reason) {
log.info("Upstream websocket closed: meetingId={}, sessionId={}, code={}, reason={}",
meetingId, rawSession.getId(), statusCode, reason);
closeFrontend(new CloseStatus(statusCode, reason));
return COMPLETED;
}
@Override
public void onError(java.net.http.WebSocket webSocket, Throwable error) {
log.error("Upstream websocket error: meetingId={}, sessionId={}, upstream={}",
meetingId, rawSession.getId(), targetWsUrl, error);
closeFrontend(CloseStatus.SERVER_ERROR);
}
private void closeFrontend(CloseStatus status) {
try {
if (rawSession.isOpen()) {
rawSession.close(status);
}
} catch (Exception ignored) {
// ignore close failure
}
}
private ByteBuffer copyBuffer(ByteBuffer source) {
ByteBuffer duplicate = source.asReadOnlyBuffer();
byte[] bytes = new byte[duplicate.remaining()];
duplicate.get(bytes);
return ByteBuffer.wrap(bytes);
}
}
}