Skip to content
This repository has been archived by the owner on Aug 3, 2023. It is now read-only.

Commit

Permalink
revert bad fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Antonio Cheong committed Mar 31, 2023
1 parent 71e958c commit 9756b1e
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions src/ImageGen.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import json
import os
import time
import urllib

import regex
import requests

Expand Down Expand Up @@ -36,7 +34,7 @@ def get_images(self, prompt: str) -> list:
prompt: str
"""
print("Sending request...")
url_encoded_prompt = urllib.parse.quote(prompt)
url_encoded_prompt = requests.utils.quote(prompt)
# https://www.bing.com/images/create?q=<PROMPT>&rt=3&FORM=GENCRE
url = f"{BING_URL}/images/create?q={url_encoded_prompt}&rt=4&FORM=GENCRE"
response = self.session.post(url, allow_redirects=False)
Expand All @@ -56,7 +54,10 @@ def get_images(self, prompt: str) -> list:
polling_url = f"{BING_URL}/images/create/async/results/{request_id}?q={url_encoded_prompt}"
# Poll for results
print("Waiting for results...")
start_wait = time.time()
while True:
if int(time.time() - start_wait) > 300:
raise Exception("Timeout error")
print(".", end="", flush=True)
response = self.session.get(polling_url)
if response.status_code != 200:
Expand All @@ -72,7 +73,20 @@ def get_images(self, prompt: str) -> list:
# Remove size limit
normal_image_links = [link.split("?w=")[0] for link in image_links]
# Remove duplicates
return list(set(normal_image_links))
normal_image_links = list(set(normal_image_links))

# Bad images
bad_images = [
"https://r.bing.com/rp/in-2zU3AJUdkgFe7ZKv19yPBHVs.png",
"https://r.bing.com/rp/TX9QuO3WzcCJz1uaaSwQAz39Kb0.jpg",
]
for im in normal_image_links:
if im in bad_images:
raise Exception("Bad images")
# No images
if not normal_image_links:
raise Exception("No images")
return normal_image_links

def save_images(self, links: list, output_dir: str) -> None:
"""
Expand Down Expand Up @@ -120,13 +134,16 @@ def save_images(self, links: list, output_dir: str) -> None:
)
args = parser.parse_args()
# Load auth cookie
with open(args.cookie_file, encoding="utf-8") as file:
cookie_json = json.load(file)
for cookie in cookie_json:
if cookie.get("name") == "_U":
args.U = cookie.get("value")
break
try:
with open(args.cookie_file, encoding="utf-8") as file:
cookie_json = json.load(file)
for cookie in cookie_json:
if cookie.get("name") == "_U":
args.U = cookie.get("value")
break

except:
pass
if args.U is None:
raise Exception("Could not find auth cookie")

Expand Down

0 comments on commit 9756b1e

Please sign in to comment.