# Two-branch binary classifier: an EfficientNetV2 image backbone and a small
# MLP over the encoded tabular features, fused by concatenation into a single
# sigmoid output. Relies on CFG, feature_space, loss and auc being defined
# earlier in the file/notebook.

# Define input layers
image_input = keras.Input(shape=(*CFG.image_size, 3), name="images")
feat_input = keras.Input(
    shape=(feature_space.get_encoded_features().shape[1],), name="features"
)
# The model takes a dict of inputs; presumably the input pipeline yields
# batches keyed "images" / "features" — confirm against the dataset code.
inp = {"images": image_input, "features": feat_input}

# Branch for image input: backbone feature map -> pooled vector -> dropout
backbone = keras_cv.models.EfficientNetV2Backbone.from_preset(CFG.preset)
x1 = backbone(image_input)
x1 = keras.layers.GlobalAveragePooling2D()(x1)
x1 = keras.layers.Dropout(0.2)(x1)

# Branch for tabular/feature input.
# SELU only self-normalizes when paired with LeCun-normal weight init and
# AlphaDropout (see the Keras `selu` docs); the default Glorot init plus
# plain Dropout would break that property.
x2 = keras.layers.Dense(
    96, activation="selu", kernel_initializer="lecun_normal"
)(feat_input)
x2 = keras.layers.Dense(
    128, activation="selu", kernel_initializer="lecun_normal"
)(x2)
x2 = keras.layers.AlphaDropout(0.1)(x2)

# Concatenate both branches
concat = keras.layers.Concatenate()([x1, x2])
# Output layer. Forced to float32 so the sigmoid output is computed in full
# precision even if a mixed-precision policy is active elsewhere.
out = keras.layers.Dense(1, activation="sigmoid", dtype="float32")(concat)
# Build model
model = keras.models.Model(inp, out)
# Compile the model (loss and auc are defined elsewhere in the file)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    loss=loss,
    metrics=[auc],
)
# Model Summary
model.summary()
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ images (InputLayer) │ (None, 128, 128, │ 0 │ - │
│ │ 3) │ │ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ features │ (None, 71) │ 0 │ - │
│ (InputLayer) │ │ │ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ efficient_net_v2b2… │ (None, 4, 4, │ 8,769,374 │ images[0][0] │
│ (EfficientNetV2Bac… │ 1408) │ │ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense (Dense) │ (None, 96) │ 6,912 │ features[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ global_average_poo… │ (None, 1408) │ 0 │ efficient_net_v2… │
│ (GlobalAveragePool… │ │ │ │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_1 (Dense) │ (None, 128) │ 12,416 │ dense[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout (Dropout) │ (None, 1408) │ 0 │ global_average_p… │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dropout_1 (Dropout) │ (None, 128) │ 0 │ dense_1[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ concatenate_1 │ (None, 1536) │ 0 │ dropout[0][0], │
│ (Concatenate) │ │ │ dropout_1[0][0] │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ dense_2 (Dense) │ (None, 1) │ 1,537 │ concatenate_1[0]… │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘
Total params: 8,790,239 (33.53 MB)
Trainable params: 8,707,951 (33.22 MB)
Non-trainable params: 82,288 (321.44 KB)