-
Notifications
You must be signed in to change notification settings - Fork 0
/
api.py
263 lines (204 loc) · 8.91 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import logging
import os
import sys
from datetime import datetime, timedelta, timezone

import pandas as pd
from flask import Flask
from flask_socketio import SocketIO, emit

from models import Generator, Predictor, Guard
from utils import get_ids, get_readable_codes
Log_Format = "%(levelname)s %(asctime)s - %(message)s"
# Log to stdout. No ``filemode`` is passed: basicConfig only uses it together
# with ``filename``, so it would have no effect alongside ``stream``.
# NOTE: at WARNING level the logger.info(...) calls in this module are not
# emitted; lower the level to see them.
logging.basicConfig(
    stream=sys.stdout,
    format=Log_Format,
    level=logging.WARNING,
)
# Root logger; shared by this module and handed to Socket.IO / Engine.IO.
logger = logging.getLogger()
# Fixed UTC-5 offset ("EST"). No daylight-saving adjustment is applied, which
# matches the original fixed 5-hour shift.
_EST_OFFSET = timezone(timedelta(hours=-5))


def get_est_now():
    """Return the current wall-clock time at a fixed UTC-5 offset.

    The value is derived from the current UTC time, so it does not depend on
    the host machine's local timezone setting. The timezone info is stripped
    so callers get a naive ``datetime``, as before (they only ``strftime`` it).

    Returns:
        datetime: naive datetime for "now" at UTC-5.
    """
    return datetime.now(timezone.utc).astimezone(_EST_OFFSET).replace(tzinfo=None)
###############################################################################
#                                 Load models                                 #
###############################################################################
# Root directory holding the model checkpoints and the data files used below.
# Set MODELS_BASE_PATH to point elsewhere (e.g. for local development) instead
# of editing this file. os.path.join(..., "") guarantees a trailing separator,
# which the plain string concatenations built on base_path rely on.
base_path = os.path.join(os.environ.get("MODELS_BASE_PATH", "/home/ubuntu/models/"), "")

predictor = Predictor(base_path + 'predictors/')
generator = Generator(base_path + 'generator')
guard = Guard(base_path)
###############################################################################
#                          Initialize data variables                          #
###############################################################################

# Timestamp formats: one for file names, one (with microseconds) for log rows.
DATETIME_FORMAT = "%Y-%m-%d_%H-%M-%S"
DATE_MICROSEC_FORMAT = "%Y-%m-%d_%H-%M-%S.%f"

# Output directory for CSV logs and the file of valid user / chat ids.
SAVE_PATH = base_path + "flask_outputs"
ID_PATH = base_path + "chat_user_ids.csv"

# Column layouts of the three in-memory logs.
DIALOG_COLUMNS = [
    'user_id', 'is_listener', 'utterance', 'time',
    'predictor_input_ids', 'generator_input_ids',
]
PRED_COLUMNS = ['code', 'score', 'last_utterance_index', 'pred_index', 'text', 'time']
CLICK_COLUMNS = ['last_utterance_index', 'pred_index', 'time']

# Make sure the log directory exists before any dump is attempted.
os.makedirs(SAVE_PATH, exist_ok=True)

# Credentials loaded from ID_PATH.
user_ids, chat_ids, listener_chat_types = get_ids(ID_PATH)

# Mutable session state, rewritten by the event handlers.
client_id = "default_client"
listener_id = "default_listener"
current_chat_id = "default"


def _empty_log(columns):
    """Build a zero-row DataFrame with the given column names."""
    return pd.DataFrame.from_dict({column: [] for column in columns})


dialog_df = _empty_log(DIALOG_COLUMNS)
pred_df = _empty_log(PRED_COLUMNS)
click_df = _empty_log(CLICK_COLUMNS)
def reset_session():
    """Put the module-level session state back to its starting values.

    The tracked ids revert to their placeholder values. Each log frame is
    truncated to zero rows while keeping its columns.
    """
    global client_id, listener_id, current_chat_id
    global dialog_df, pred_df, click_df

    client_id = "default_client"
    listener_id = "default_listener"
    current_chat_id = "default"

    # An empty slice keeps each frame's column layout but drops every row.
    dialog_df, pred_df, click_df = dialog_df[0:0], pred_df[0:0], click_df[0:0]
###############################################################################
#                       Actually setup the Api resource                       #
###############################################################################
app = Flask(__name__)

# Socket.IO server wrapping the Flask app. Both the Socket.IO and Engine.IO
# layers log through the module's root logger.
# NOTE(review): cors_allowed_origins="*" accepts connections from any origin.
socketio = SocketIO(
    app,
    logger=logger,
    engineio_logger=logger,
    cors_allowed_origins="*",
    ping_interval=60,
)

logger.info("Backend ready")
###############################################################################
#                                    Events                                   #
###############################################################################
@socketio.on("log_user")
def log_user(chat_id, user_id):
    """Record who is involved in this conversation and validate the user.

    Args:
        chat_id (str): chat id assigned to the user
        user_id (str): id assigned to the user

    Both ids are lower-cased and stripped before validation.

    Emits "login_response" to the same user
        args (dict):
            valid (bool): whether the (chat_id, user_id) pair is valid
            is_listener (bool): (if valid) whether this user is the Listener in this mock chat
            show_suggestions: (if valid) the listener's per-chat setting from
                listener_chat_types; always False for clients
    Emits "error" with the exception message if anything unexpected fails.
    """
    # Only these are rebound here; user_ids / chat_ids / listener_chat_types
    # are read-only and need no global declaration.
    global listener_id, client_id, current_chat_id
    try:
        chat_id = chat_id.lower().strip()
        user_id = user_id.lower().strip()

        # Unknown user ids are rejected outright.
        if user_id not in user_ids:
            emit("login_response", {"valid": False})
            return

        if user_id in listener_chat_types:
            # Listeners may only join chats assigned to them. The assignment
            # also decides whether they see suggestions in that chat.
            assigned_chats = listener_chat_types[user_id]
            if chat_id not in assigned_chats:
                emit("login_response", {"valid": False})
                return
            is_listener = True
            listener_id = user_id
            show_suggestions = assigned_chats[chat_id]
        else:
            # Everyone else is a client: any known chat id is accepted and
            # suggestions are never shown.
            if chat_id not in chat_ids:
                emit("login_response", {"valid": False})
                return
            is_listener = False
            client_id = user_id
            show_suggestions = False

        current_chat_id = chat_id
        emit("login_response", {
            "valid": True,
            "is_listener": is_listener,
            "show_suggestions": show_suggestions,
        })
        logger.info("{} logged in successfully as a {}.".format(
            user_id, "Listener" if is_listener else "Client"
        ))
    except Exception as e:
        # Keep the traceback server-side; the client only gets the message.
        logger.exception("Error while handling log_user")
        emit("error", str(e))
def _dedupe_predictions(code_scores, generations, blacklisted):
    """Keep each generated text once and drop the ones the guard flagged.

    Args:
        code_scores (list): (code, score) pairs, aligned with ``generations``.
        generations (list): generated texts, one per code.
        blacklisted (list): guard output per generation; ``1`` means drop it.

    Returns:
        tuple[list, list]: the surviving (code, score) pairs and their
        generations, in original order. If several codes produce the same
        text, only the first is kept.
    """
    seen = set()
    kept_code_scores, kept_gens = [], []
    for code_score, gen, flag in zip(code_scores, generations, blacklisted):
        if gen in seen or flag == 1:
            continue
        seen.add(gen)
        kept_code_scores.append(code_score)
        kept_gens.append(gen)
    return kept_code_scores, kept_gens


@socketio.on("add_message")
def add_message(is_listener, utterance):
    """Add a new utterance to the backend and broadcast it with suggestions.

    Args:
        is_listener (bool): whether the message is sent by the listener
        utterance (str): the new message sent

    The utterance is appended to ``dialog_df``. The predictor, generator and
    guard models then produce candidate next utterances. These are filtered
    by score threshold, de-duplicated, and recorded in ``pred_df``.

    Emits "new_message" to ALL users
        args (dict):
            is_listener (bool): whether the new message is sent by the listener
            utterance (str): the new message sent
            suggestions: readable codes for the kept predictions (from get_readable_codes)
            predictions (List[str]): list of predicted next utterance in order
    Emits "error" with the exception message if anything fails.
    """
    try:
        user_id = listener_id if is_listener else client_id
        now = get_est_now().strftime(DATE_MICROSEC_FORMAT)
        # The new row lands at position len(dialog_df). Predictions below are
        # recorded against that index.
        last_utterance_index = len(dialog_df.index)
        dialog_df.loc[last_utterance_index] = [user_id, is_listener, utterance, now, [], []]

        code_scores = predictor.predict(dialog_df)
        top_code_scores = [cs for cs in code_scores if cs[1] > Predictor.PRED_THRESHOLD]
        generations = generator.predict(dialog_df, top_code_scores)
        blacklisted = guard.predict(generations)
        # Don't give same generations under different codes
        deduped_code_scores, deduped_gens = _dedupe_predictions(
            top_code_scores, generations, blacklisted
        )

        now = get_est_now().strftime(DATE_MICROSEC_FORMAT)
        for pred_index, ((code, score), text) in enumerate(zip(deduped_code_scores, deduped_gens)):
            # Append a whole row with .loc. .at is a scalar-only accessor, and
            # recent pandas versions raise when it is given a list.
            pred_df.loc[len(pred_df)] = [code, score, last_utterance_index, pred_index, text, now]

        readable_codes = get_readable_codes(deduped_code_scores)
        args = {
            "is_listener": is_listener,
            "utterance": utterance,
            "suggestions": readable_codes,
            "predictions": deduped_gens,
        }
        emit("new_message", args, broadcast=True)  # Send to all clients
    except Exception as e:
        logger.exception("Error while handling add_message")
        emit("error", str(e))
@socketio.on("log_click")
def log_click(index):
    """Record when, and which, prediction was clicked.

    Args:
        index (int): index of the clicked prediction (0-indexed); corresponds
            to ``pred_index`` in the prediction log.

    Appends ``[last_utterance_index, index, time]`` to ``click_df``.
    Emits "error" with the exception message if anything fails.
    """
    try:
        # add_message records predictions under the index of the utterance it
        # appended, i.e. len(dialog_df) taken *before* the append. When a
        # prediction is clicked, that utterance is the last row, so the
        # matching index is len(dialog_df) - 1 (-1 if no utterance exists yet).
        last_utterance_index = len(dialog_df.index) - 1
        now = get_est_now().strftime(DATE_MICROSEC_FORMAT)
        click_df.loc[len(click_df)] = [last_utterance_index, index, now]
    except Exception as e:
        logger.exception("Error while handling log_click")
        emit("error", str(e))
@socketio.on("dump_logs")
def dump_logs():
    """Write the dialog, prediction, and click logs of this session to CSV.

    Files go to ``SAVE_PATH`` as
    ``<timestamp>_<current_chat_id>_{dialog,pred,click}.csv``.
    The in-memory logs are NOT cleared here; reset_session (triggered by the
    "clear_session" event) does that.
    Emits "error" with the exception message if anything fails.
    """
    try:
        date_time = get_est_now().strftime(DATETIME_FORMAT)
        prefix = f"{SAVE_PATH}/{date_time}_{current_chat_id}_"
        # Force integer dtype so these index columns are written as ints
        # (row-by-row appends may leave them as float/object).
        pred_df['last_utterance_index'] = pred_df['last_utterance_index'].astype(int)
        pred_df['pred_index'] = pred_df['pred_index'].astype(int)
        # The last two dialog columns (model input ids) are not written out.
        dialog_df.to_csv(prefix + "dialog.csv", index=True, columns=DIALOG_COLUMNS[:-2])
        pred_df.to_csv(prefix + "pred.csv", index=False)
        click_df.to_csv(prefix + "click.csv", index=False)
        logger.info("Dumpped logs successfully")
    except Exception as e:
        logger.exception("Error while handling dump_logs")
        emit("error", str(e))
@socketio.on("clear_session")
def clear_session():
    """Reset the in-memory session state (ids and logs) via reset_session.

    Emits "error" with the exception message if anything fails.
    """
    try:
        reset_session()
        logger.info("Cleared session successfully")
    except Exception as e:
        logger.exception("Error while handling clear_session")
        emit("error", str(e))
@socketio.on("is_typing")
def is_typing(is_typing, is_listener):
    """Relay a typing-indicator update to every connected user.

    Args:
        is_typing: typing flag sent by the frontend (presumably a bool)
        is_listener (bool): whether the sender is the listener

    Emits "is_typing" to ALL users with the same two fields.
    """
    # This event can fire very often, so trace it at debug level through the
    # logger instead of printing unconditionally to stdout.
    logger.debug("is_typing=%s is_listener=%s", is_typing, is_listener)
    args = {
        "is_typing": is_typing,
        "is_listener": is_listener,
    }
    emit("is_typing", args, broadcast=True)
if __name__ == '__main__':
    # Serve on all interfaces at port 8000 with Flask debug mode off.
    server_options = dict(debug=False, host='0.0.0.0', port=8000)
    socketio.run(app, **server_options)