Description
Hi, I would like to add an example of a Siamese network, since it is a popular use case in ML. I plan to implement it in the same style as the other examples, i.e. with command-line arguments for choosing the dataset to train on, the hyperparameters, etc.
Is there anything specific I need to keep in mind apart from these:
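To make the command-line part concrete, here is a rough sketch of the kind of interface I have in mind. The flag names, defaults, and dataset choices are placeholders I made up for illustration; they are not taken from any existing example.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser for the proposed example (all flags are hypothetical)."""
    parser = argparse.ArgumentParser(description="Siamese network example (sketch)")
    parser.add_argument("--dataset", default="mnist",
                        choices=["mnist", "fashion-mnist"],
                        help="dataset to train on (placeholder choices)")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="number of triplets per batch")
    parser.add_argument("--epochs", type=int, default=14,
                        help="number of training epochs")
    parser.add_argument("--lr", type=float, default=1.0,
                        help="learning rate")
    parser.add_argument("--margin", type=float, default=1.0,
                        help="margin used by the triplet loss")
    return parser


# Parse an explicit argument list here so the sketch runs without a real CLI call.
args = build_parser().parse_args(["--dataset", "mnist", "--epochs", "2"])
print(args)
```

In the real script the list passed to `parse_args` would be dropped so that the values come from `sys.argv`.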
- Use torchvision's dataset classes and PyTorch's DataLoader class to handle data.
- Implement a simple CNN as an nn.Module subclass
- Implement triplet loss
- Create train and test functions and a main function that calls both of them at each epoch.
- Report final loss and accuracy
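
A minimal sketch of how the model, the loss, and a single training step from the list above could fit together. The names `EmbeddingNet` and `train_step`, the layer sizes, and the 1x28x28 input shape are my own assumptions, and random tensors stand in for a real batch of triplets. The loss uses PyTorch's built-in `nn.TripletMarginLoss` rather than a hand-written one; triplet sampling, the test loop, and accuracy reporting are left out.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EmbeddingNet(nn.Module):
    """Small CNN mapping a 1x28x28 image to an embedding (sizes are assumptions)."""

    def __init__(self, embedding_dim: int = 32):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.fc = nn.Linear(32 * 7 * 7, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 14x14 -> 7x7
        return self.fc(torch.flatten(x, 1))


def train_step(model, optimizer, criterion, anchor, positive, negative) -> float:
    """Run one optimisation step on a batch of (anchor, positive, negative) triplets."""
    model.train()
    optimizer.zero_grad()
    # The same network (shared weights) embeds all three inputs.
    loss = criterion(model(anchor), model(positive), model(negative))
    loss.backward()
    optimizer.step()
    return loss.item()


torch.manual_seed(0)
model = EmbeddingNet()
criterion = nn.TripletMarginLoss(margin=1.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Random tensors stand in for a real triplet batch drawn from a DataLoader.
anchor, positive, negative = (torch.randn(8, 1, 28, 28) for _ in range(3))
loss = train_step(model, optimizer, criterion, anchor, positive, negative)
print(f"triplet loss after one step: {loss:.4f}")
```

In the full example, the main function would call a train function built around this step, followed by a test function, once per epoch.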
Is this something that is worth adding to the repository?