From 9756b1eed91d7b0d2687d5a116126e975799f0fd Mon Sep 17 00:00:00 2001 From: Antonio Cheong Date: Fri, 31 Mar 2023 11:20:15 +0800 Subject: [PATCH] revert bad fix --- src/ImageGen.py | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/src/ImageGen.py b/src/ImageGen.py index 2b1088b07..4b294c9cc 100644 --- a/src/ImageGen.py +++ b/src/ImageGen.py @@ -1,8 +1,6 @@ import json import os import time -import urllib - import regex import requests @@ -36,7 +34,7 @@ def get_images(self, prompt: str) -> list: prompt: str """ print("Sending request...") - url_encoded_prompt = urllib.parse.quote(prompt) + url_encoded_prompt = requests.utils.quote(prompt) # https://www.bing.com/images/create?q=&rt=3&FORM=GENCRE url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE" response = self.session.post(url, allow_redirects=False) @@ -56,7 +54,10 @@ def get_images(self, prompt: str) -> list: polling_url = f"{BING_URL}/images/create/async/results/{request_id}?q={url_encoded_prompt}" # Poll for results print("Waiting for results...") + start_wait = time.time() while True: + if int(time.time() - start_wait) > 300: + raise Exception("Timeout error") print(".", end="", flush=True) response = self.session.get(polling_url) if response.status_code != 200: @@ -72,7 +73,20 @@ def get_images(self, prompt: str) -> list: # Remove size limit normal_image_links = [link.split("?w=")[0] for link in image_links] # Remove duplicates - return list(set(normal_image_links)) + normal_image_links = list(set(normal_image_links)) + + # Bad images + bad_images = [ + "https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png", + "https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg", + ] + for im in normal_image_links: + if im in bad_images: + raise Exception("Bad images") + # No images + if not normal_image_links: + raise Exception("No images") + return normal_image_links def save_images(self, links: list, output_dir: str) -> None: """ @@ -120,13 +134,16 @@ def save_images(self, links: list, output_dir: str) -> None: ) args = parser.parse_args() # Load auth cookie - with open(args.cookie_file, encoding="utf-8") as file: - cookie_json = json.load(file) - for cookie in cookie_json: - if cookie.get("name") == "_U": - args.U = cookie.get("value") - break + try: + with open(args.cookie_file, encoding="utf-8") as file: + cookie_json = json.load(file) + for cookie in cookie_json: + if cookie.get("name") == "_U": + args.U = cookie.get("value") + break + except: + pass if args.U is None: raise Exception("Could not find auth cookie")