# Run on the first GPU when Ivy reports one is available; otherwise use the CPU.
if ivy.gpu_is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Training hyperparameters.
optimizer = ivy.Adam(1e-4)
batch_size = 64
num_epochs = 20
num_classes = 10

# Network to train, configured with h_w=(28, 28) and a single input channel.
model = IvyNet(
    h_w=(28, 28),
    input_channels=1,
    output_channels=120,
    num_classes=num_classes,
    device=device,
)

# Lower-cased class name of the model; used to name the metrics summary file.
model_name = type(model).__name__.lower()
# training loop
def train(images, classes, epochs, model, device, num_classes=10, batch_size=32):
    """Train ``model`` for ``epochs`` passes and write per-epoch metrics to CSV.

    Each batch produced by ``generate_batches`` is one-hot encoded, run through
    ``ivy.execute_with_gradients`` with ``loss_fn``, and the result is used to
    update ``model.v`` via the module-level ``optimizer``. After the last epoch
    the collected metrics are written to ``/{model_name}_train_summary.csv``.

    Besides its parameters, this function reads the module-level names
    ``optimizer``, ``model_name``, ``loss_fn``, ``num_correct`` and
    ``generate_batches``.

    Args:
        images: Training inputs; must support ``len()``.
        classes: Target labels matching ``images``; one-hot encoded per batch.
        epochs: Number of passes over the dataset.
        model: Model whose ``v`` attribute holds the trainable variables.
        device: Device string; any value other than ``"cpu"`` moves each batch
            to ``"gpu:0"`` before the forward pass.
        num_classes: Number of classes used for the one-hot encoding.
        batch_size: Batch size forwarded to ``generate_batches``.

    Returns:
        None. Results are reported through the progress bar and the CSV file.
    """
    fields = ["epoch", "epoch_loss", "training_accuracy"]
    metrics = []
    dataset_size = len(images)

    for epoch in range(epochs):
        # Fresh accumulators every epoch: the previous epoch's average loss
        # must not leak into this epoch's running sum.
        loss_sum = 0.0
        train_correct = 0

        # NOTE(review): `total` floors the batch count; if generate_batches
        # also yields a final partial batch the bar will overshoot by one --
        # confirm against generate_batches.
        train_loop = tqdm(
            generate_batches(images, classes, dataset_size, batch_size=batch_size),
            total=dataset_size // batch_size,
            position=0,
            leave=True,
        )

        for xbatch, ybatch in train_loop:
            if device != "cpu":
                xbatch, ybatch = xbatch.to_device("gpu:0"), ybatch.to_device("gpu:0")

            # The cross entropy function expects one-hot encoded targets.
            ybatch_encoded = ivy.one_hot(ybatch, num_classes)

            # Compute loss and gradients, then update the model parameters.
            # grads["0"] is taken as the gradients matching model.v, the first
            # entry of the argument tuple.
            loss_probs, grads = ivy.execute_with_gradients(
                loss_fn,
                (model.v, model, xbatch, ybatch_encoded),
            )
            model.v = optimizer.step(model.v, grads["0"])

            batch_loss = ivy.to_numpy(loss_probs[0]).mean().item()  # batch mean loss
            # Weight by the batch length so the epoch figure is a per-sample mean.
            loss_sum += batch_loss * xbatch.shape[0]
            train_correct += num_correct(loss_probs[1], ybatch)

            train_loop.set_description(f"Epoch [{epoch + 1:2d}/{epochs}]")
            train_loop.set_postfix(
                running_loss=batch_loss,
                accuracy_percentage=(train_correct / dataset_size) * 100,
            )

        epoch_loss = loss_sum / dataset_size
        training_accuracy = train_correct / dataset_size
        metrics.append([epoch, epoch_loss, training_accuracy])

        train_loop.write(
            f"\nAverage training loss: {epoch_loss:.6f}, Train Correct: {train_correct}",
            end="\n",
        )

    # Write metrics for plotting.
    # NOTE(review): the leading "/" targets the filesystem root -- confirm
    # that this location is intended and writable.
    # newline="" lets the csv module control line endings itself.
    with open(f"/{model_name}_train_summary.csv", "w", newline="") as summary_file:
        writer = csv.writer(summary_file)
        writer.writerow(fields)
        writer.writerows(metrics)
# Assumes the dataset (``images`` and ``classes``) has already been prepared.
train(
    images,
    classes,
    num_epochs,
    model,
    device,
    num_classes=num_classes,
    batch_size=batch_size,
)