-
Notifications
You must be signed in to change notification settings - Fork 0
/
utility.py
77 lines (61 loc) · 2.59 KB
/
utility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import zipfile
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2, DenseNet169, ResNet50, VGG19
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.optimizers import Adam
#_____________________________________________________________________________________________
def unzip_data(zip_file_path, extract_dir):
    '''
    Unpack every member of a zip archive into a destination directory.

    zip_file_path: path to the zip file to read
    extract_dir: destination directory the archive contents are written to
    '''
    # The context manager closes the archive even if extraction fails
    with zipfile.ZipFile(zip_file_path) as archive:
        archive.extractall(path=extract_dir)
#_____________________________________________________________________________________________
def generator(inp_dir):
    '''
    Create a directory-backed image iterator for the COVID / Normal dataset.

    inp_dir: directory to the file containing COVID and Normal folders

    Returns the iterator produced by ImageDataGenerator.flow_from_directory,
    configured with a (224, 224) target size, binary class mode and a batch
    size of 32.
    '''
    # Options forwarded unchanged to flow_from_directory
    flow_options = {
        'target_size': (224, 224),
        'class_mode': 'binary',
        'batch_size': 32,
    }
    # rescale multiplies every pixel value by 1/255
    rescaling_datagen = ImageDataGenerator(rescale=1. / 255)
    return rescaling_datagen.flow_from_directory(inp_dir, **flow_options)
#_____________________________________________________________________________________________
def set_model(model_name, input_shape=(224, 224, 3)):
    '''
    Build a binary classifier from a Keras application backbone (created
    with include_top=False) followed by a small dense head that ends in a
    single sigmoid unit.

    model_name: can be one of the following values:
        [mobilenet2, densenet, resnet50, vgg19]
    input_shape: input shape handed to the backbone constructor; defaults to
        (224, 224, 3), the value that used to be hard-coded for every backbone

    Returns the assembled tf.keras.Model, or None (after printing a hint)
    when model_name is not one of the supported values.
    '''
    supported_names = ('mobilenet2', 'densenet', 'resnet50', 'vgg19')
    # Guard clause: reject unsupported names before any backbone is touched.
    # Printing and returning None (instead of raising) is kept on purpose so
    # existing callers see the same behaviour as before.
    if model_name not in supported_names:
        print('Please enter a valid model name from the following list! \n mobilenet2, densenet, resnet50, vgg19')
        return None

    # One dispatch table instead of four near-identical constructor calls
    backbone_classes = {
        'mobilenet2': MobileNetV2,
        'densenet': DenseNet169,
        'resnet50': ResNet50,
        'vgg19': VGG19,
    }
    feature_model = backbone_classes[model_name](include_top=False,
                                                 input_shape=input_shape)

    # Global-average-pool the backbone's feature maps (this is pooling, not
    # a Flatten layer)
    x = GlobalAveragePooling2D(name='GAP')(feature_model.output)
    # Add a fully connected layer with 4 hidden units and ReLU activation
    x = Dense(4, activation='relu', name='dense4')(x)
    # Add a dropout rate of 0.25
    x = Dropout(0.25, name='drop')(x)
    # Add a final sigmoid layer for classification
    x = Dense(1, activation='sigmoid', name='dense1')(x)
    # Append the dense network to the feature_model
    return tf.keras.Model(feature_model.input, x)
#_____________________________________________________________________________________________