-
Notifications
You must be signed in to change notification settings - Fork 0
/
cop_web.py
82 lines (65 loc) · 2.3 KB
/
cop_web.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import base64
import hashlib
import io
import mimetypes
import os
import re
import shutil
import tempfile
import urllib.request

import streamlit as st
import torch
from PIL import Image
from torchvision import transforms
def _background_css(image_file):
    """Return a ``<style>`` block that paints *image_file* as the app background.

    The file's bytes are inlined as a base64 ``data:`` URI, so no static file
    serving is needed.
    """
    # Derive the MIME type from the extension ("image/jpeg", "image/png", ...);
    # fall back to JPEG when the extension is unknown.
    mime_type = mimetypes.guess_type(image_file)[0] or "image/jpeg"
    with open(image_file, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode("ascii")
    # Lines start at column 0 so Markdown does not treat the HTML as an
    # indented code block.
    return (
        "<style>\n"
        ".stApp {\n"
        f"    background-image: url(data:{mime_type};base64,{encoded});\n"
        "    background-size: cover;\n"
        "}\n"
        "</style>\n"
    )


def add_bg_from_local(image_file):
    """Set a local image file as the full-page background of the Streamlit app.

    Args:
        image_file: Path to the image on disk.
    """
    # unsafe_allow_html is required for st.markdown to emit the raw <style> tag.
    st.markdown(_background_css(image_file), unsafe_allow_html=True)
# Locations of the trained classifier and its label list (one label per line).
# NOTE(review): both point at the Google Drive viewer page ("/view?usp=share_link"),
# not at a direct file download.
MODEL_PATH = 'https://drive.google.com/file/d/1Smm9ZSsv1gZJgoKft-YcFpsjhG5rZaS3/view?usp=share_link'
LABELS_PATH = 'https://drive.google.com/file/d/1CU0_KIRA65vTK6MQABFCoGwBhwHMgQqf/view?usp=share_link'

# Executed at import time, i.e. before main() builds any widgets.
add_bg_from_local('images/bg2.jpg')
def load_image():
    """Render an uploader widget and return the uploaded picture as a PIL image.

    The uploaded bytes are also displayed on the page via ``st.image``.

    Returns:
        The opened image, or ``None`` while nothing has been uploaded.
    """
    upload = st.file_uploader(label='Pick a banknote to test')
    if upload is None:
        return None
    raw_bytes = upload.getvalue()
    st.image(raw_bytes)
    return Image.open(io.BytesIO(raw_bytes))
_DRIVE_FILE_ID = re.compile(r"drive\.google\.com/file/d/([\w-]+)")


def _drive_direct_url(url):
    """Rewrite a Google Drive ``/file/d/<id>/...`` share link to a direct-download URL.

    Any other URL is returned unchanged.
    """
    match = _DRIVE_FILE_ID.search(url)
    if match is None:
        return url
    # confirm=t skips Drive's "cannot scan for viruses" interstitial on large files.
    return f"https://drive.google.com/uc?export=download&confirm=t&id={match.group(1)}"


def _fetch_to_local(path_or_url):
    """Return something ``open``/``torch.load`` can read for *path_or_url*.

    http(s) URLs are downloaded once into a cache directory under the system
    temp dir (keyed by a hash of the URL), so later Streamlit reruns reuse the
    file. Anything else (local path, PathLike, file object) is returned as-is.

    Raises:
        ValueError: if the server answers with an HTML page instead of a file
            (typically a Drive permission or error page).
    """
    if not isinstance(path_or_url, str) or not path_or_url.startswith(("http://", "https://")):
        return path_or_url
    url = _drive_direct_url(path_or_url)
    cache_dir = os.path.join(tempfile.gettempdir(), "cop_web_cache")
    os.makedirs(cache_dir, exist_ok=True)
    local_path = os.path.join(cache_dir, hashlib.sha256(url.encode("utf-8")).hexdigest())
    if not os.path.exists(local_path):
        partial_path = local_path + ".part"
        with urllib.request.urlopen(url, timeout=60) as response:
            if response.headers.get_content_type() == "text/html":
                raise ValueError(
                    f"{url} returned an HTML page instead of a file; "
                    "check the link and its sharing settings"
                )
            # Download to a side file and rename, so an interrupted download
            # is never mistaken for a complete cached copy.
            with open(partial_path, "wb") as out:
                shutil.copyfileobj(response, out)
        os.replace(partial_path, local_path)
    return local_path


def load_model(model_path):
    """Load a fully pickled torch model onto the CPU and switch it to eval mode.

    Args:
        model_path: Local path, file-like object, or http(s) URL (Google Drive
            share links are rewritten to direct downloads).

    Returns:
        The deserialized model after ``eval()``.
    """
    source = _fetch_to_local(model_path)
    # The checkpoint is a whole nn.Module (we call .eval() on it), which
    # torch>=2.6's default weights_only=True refuses to unpickle.
    # SECURITY: weights_only=False can execute arbitrary code from the file;
    # only load checkpoints from a trusted source.
    try:
        model = torch.load(source, map_location="cpu", weights_only=False)
    except TypeError:
        # torch < 1.13 has no weights_only parameter.
        model = torch.load(source, map_location="cpu")
    model.eval()
    return model


def load_labels(labels_file):
    """Read class labels, one per line, ignoring blank lines.

    Args:
        labels_file: Local path or http(s) URL (Drive share links supported).

    Returns:
        The stripped labels in file order; position i names model output i.
    """
    # utf-8-sig also tolerates a leading byte-order mark.
    with open(_fetch_to_local(labels_file), "r", encoding="utf-8-sig") as f:
        # Blank lines (e.g. extra trailing newlines) must not become empty
        # categories: they would inflate len(categories) used for topk.
        return [line.strip() for line in f if line.strip()]
def predict(model, categories, image):
    """Classify *image* and write every label with its probability to the page.

    Args:
        model: Callable torch model; ``model(batch)[0]`` is taken as the
            per-class scores for the single input image.
        categories: Labels indexed by model output position.
        image: A PIL image (any mode; converted to RGB here).
    """
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Uploaded PNGs are often RGBA (or grayscale); Normalize above has exactly
    # three channel statistics, so force 3-channel RGB first.
    input_batch = preprocess(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        output = model(input_batch)
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    # topk over all categories returns them sorted from most to least likely.
    top_prob, top_catid = torch.topk(probabilities, len(categories))
    for prob, cat_id in zip(top_prob.tolist(), top_catid.tolist()):
        st.write(categories[cat_id], prob)
def main():
    """Build the page: title, model/labels, image uploader and Predict button."""
    st.title('Colombian Peso banknote Detection')
    model = load_model(MODEL_PATH)
    categories = load_labels(LABELS_PATH)
    image = load_image()
    if st.button('Predict image'):
        # Without an upload, predict() would be handed None and crash in preprocessing.
        if image is None:
            st.warning('Please upload a banknote image first.')
            return
        st.write('Checking...')
        predict(model, categories, image)


if __name__ == '__main__':
    main()