diff --git a/src/EdgeGPT/chathub.py b/src/EdgeGPT/chathub.py index 35e31096e..1a371ff74 100644 --- a/src/EdgeGPT/chathub.py +++ b/src/EdgeGPT/chathub.py @@ -3,12 +3,12 @@ import os import ssl import sys +import aiohttp from time import time from typing import Generator from typing import List from typing import Union -from websockets.client import connect, WebSocketClientProtocol import certifi import httpx from BingImageCreator import ImageGenAsync @@ -59,6 +59,11 @@ def __init__( timeout=900, headers=HEADERS_INIT_CONVER, ) + cookies = {} + if self.cookies is not None: + for cookie in self.cookies: + cookies[cookie["name"]] = cookie["value"] + self.aio_session = aiohttp.ClientSession(cookies=cookies) async def get_conversation( self, @@ -99,12 +104,11 @@ async def ask_stream( """ """ # Check if websocket is closed - async with connect( + async with self.aio_session.ws_connect( wss_link or "wss://sydney.bing.com/sydney/ChatHub", - extra_headers=HEADERS, - max_size=None, ssl=ssl_context, - ping_interval=None, + headers=HEADERS, + proxy=self.proxy ) as wss: await self._initial_handshake(wss) # Construct a ChatHub request @@ -116,14 +120,14 @@ async def ask_stream( locale=locale, ) # Send request - await wss.send(append_identifier(self.request.struct)) + await wss.send_str(append_identifier(self.request.struct)) draw = False resp_txt = "" result_text = "" resp_txt_no_link = "" retry_count = 5 while not wss.closed: - msg = await wss.recv() + msg = await wss.receive_str() if not msg: retry_count -= 1 if retry_count == 0: @@ -135,7 +139,7 @@ async def ask_stream( continue for obj in objects: if int(time()) % 6 == 0: - await wss.send(append_identifier({"type": 6})) + await wss.send_str(append_identifier({"type": 6})) if obj is None or not obj: continue response = json.loads(obj) @@ -226,16 +230,16 @@ async def ask_stream( return if response.get("type") != 2: if response.get("type") == 6: - await wss.send(append_identifier({"type": 6})) + await wss.send_str(append_identifier({"type": 6})) elif response.get("type") == 7: - await wss.send(append_identifier({"type": 7})) + await wss.send_str(append_identifier({"type": 7})) elif raw: yield False, response - async def _initial_handshake(self, wss: WebSocketClientProtocol) -> None: - await wss.send(append_identifier({"protocol": "json", "version": 1})) - await wss.recv() - await wss.send(append_identifier({"type": 6})) + async def _initial_handshake(self, wss) -> None: + await wss.send_str(append_identifier({"protocol": "json", "version": 1})) + await wss.receive_str() + await wss.send_str(append_identifier({"type": 6})) async def delete_conversation( self,