diff --git a/.gitignore b/.gitignore index 14ec8ef205..56cd7649a4 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,6 @@ docs/venv # vi backups *~ + +# development +.vscode \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1e6e23fd44..1c15a5513e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -26,7 +26,7 @@ If you're new we encourage you to take a look at issues tagged with [good first 5. Verify that there are no issues in your doc build. You can check preview locally by installing [sphinx-serve](https://pypi.org/project/sphinx-serve/) and - then running `sphinx-serve -d build`. + then running `sphinx-serve -b build`. 5. Ensure your test passes locally 6. If you haven't already, complete the Contributor License Agreement ("CLA"). diff --git a/docs/source/index.rst b/docs/source/index.rst index a5a89d8644..dffc26ab11 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -17,6 +17,17 @@ experiment with PyTorch. --- + Measuring Similarity using Siamese Network + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + This example demonstrates how to measure similarity between two images + using `Siamese network `__ + on the `MNIST `__ database. + + `GO TO EXAMPLE `__ :opticon:`link-external` + + --- + Word-level Language Modeling using RNN and Transformer ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/run_python_examples.sh b/run_python_examples.sh index 10fd427bb6..7244ff7ccf 100755 --- a/run_python_examples.sh +++ b/run_python_examples.sh @@ -110,6 +110,11 @@ function regression() { python main.py --epochs 1 $CUDA_FLAG || error "regression failed" } +function siamese_network() { + start + python main.py --epochs 1 --dry-run || error "siamese network example failed" +} + function reinforcement_learning() { start python reinforce.py || error "reinforcement learning reinforce failed" @@ -193,6 +198,7 @@ function run_all() { mnist_hogwild regression reinforcement_learning + siamese_network super_resolution time_sequence_prediction vae diff --git a/siamese_network/README.md b/siamese_network/README.md new file mode 100644 index 0000000000..973a0414a4 --- /dev/null +++ b/siamese_network/README.md @@ -0,0 +1,7 @@ +# Siamese Network Example + +```bash +pip install -r requirements.txt +python main.py +# CUDA_VISIBLE_DEVICES=2 python main.py # to specify GPU id to ex. 2 +``` diff --git a/siamese_network/main.py b/siamese_network/main.py new file mode 100644 index 0000000000..33a5f71517 --- /dev/null +++ b/siamese_network/main.py @@ -0,0 +1,296 @@ +from __future__ import print_function +import argparse, random, copy +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torchvision +from torch.utils.data import Dataset +from torchvision import datasets +from torchvision import transforms as T +from torch.optim.lr_scheduler import StepLR + + +class SiameseNetwork(nn.Module): + """ + Siamese network for image similarity estimation. + The network is composed of two identical networks, one for each input. + The output of each network is concatenated and passed to a linear layer. + The output of the linear layer passed through a sigmoid function. + `"FaceNet" `_ is a variant of the Siamese network. + This implementation varies from FaceNet as we use the `ResNet-18` model from + `"Deep Residual Learning for Image Recognition" `_ as our feature extractor. + In addition, we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick. + """ + def __init__(self): + super(SiameseNetwork, self).__init__() + # get resnet model + self.resnet = torchvision.models.resnet18(pretrained=False) + + # over-write the first conv layer to be able to read MNIST images + # as resnet18 reads (3,x,x) where 3 is RGB channels + # whereas MNIST has (1,x,x) where 1 is a gray-scale channel + self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) + self.fc_in_features = self.resnet.fc.in_features + + # remove the last layer of resnet18 (linear layer which is before avgpool layer) + self.resnet = torch.nn.Sequential(*(list(self.resnet.children())[:-1])) + + # add linear layers to compare between the features of the two images + self.fc = nn.Sequential( + nn.Linear(self.fc_in_features * 2, 256), + nn.ReLU(inplace=True), + nn.Linear(256, 1), + ) + + self.sigmoid = nn.Sigmoid() + + # initialize the weights + self.resnet.apply(self.init_weights) + self.fc.apply(self.init_weights) + + def init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform(m.weight) + m.bias.data.fill_(0.01) + + def forward_once(self, x): + output = self.resnet(x) + output = output.view(output.size()[0], -1) + return output + + def forward(self, input1, input2): + # get two images' features + output1 = self.forward_once(input1) + output2 = self.forward_once(input2) + + # concatenate both images' features + output = torch.cat((output1, output2), 1) + + # pass the concatenation to the linear layers + output = self.fc(output) + + # pass the out of the linear layers to sigmoid layer + output = self.sigmoid(output) + + return output + +class APP_MATCHER(Dataset): + def __init__(self, root, train, download=False): + super(APP_MATCHER, self).__init__() + + # get MNIST dataset + self.dataset = datasets.MNIST(root, train=train, download=download) + + # as `self.dataset.data`'s shape is (Nx28x28), where N is the number of + # examples in MNIST dataset, a single example has the dimensions of + # (28x28) for (WxH), where W and H are the width and the height of the image. + # However, every example should have (CxWxH) dimensions where C is the number + # of channels to be passed to the network. As MNIST contains gray-scale images, + # we add an additional dimension to corresponds to the number of channels. + self.data = self.dataset.data.unsqueeze(1).clone() + + self.group_examples() + + def group_examples(self): + """ + To ease the accessibility of data based on the class, we will use `group_examples` to group + examples based on class. + + Every key in `grouped_examples` corresponds to a class in MNIST dataset. For every key in + `grouped_examples`, every value will conform to all of the indices for the MNIST + dataset examples that correspond to that key. + """ + + # get the targets from MNIST dataset + np_arr = np.array(self.dataset.targets.clone()) + + # group examples based on class + self.grouped_examples = {} + for i in range(0,10): + self.grouped_examples[i] = np.where((np_arr==i))[0] + + def __len__(self): + return self.data.shape[0] + + def __getitem__(self, index): + """ + For every example, we will select two images. There are two cases, + positive and negative examples. For positive examples, we will have two + images from the same class. For negative examples, we will have two images + from different classes. + + Given an index, if the index is even, we will pick the second image from the same class, + but it won't be the same image we chose for the first class. This is used to ensure the positive + example isn't trivial as the network would easily distinguish the similarity between same images. However, + if the network were given two different images from the same class, the network will need to learn + the similarity between two different images representing the same class. If the index is odd, we will + pick the second image from a different class than the first image. + """ + + # pick some random class for the first image + selected_class = random.randint(0, 9) + + # pick a random index for the first image in the grouped indices based of the label + # of the class + random_index_1 = random.randint(0, self.grouped_examples[selected_class].shape[0]-1) + + # pick the index to get the first image + index_1 = self.grouped_examples[selected_class][random_index_1] + + # get the first image + image_1 = self.data[index_1].clone().float() + + # same class + if index % 2 == 0: + # pick a random index for the second image + random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0]-1) + + # ensure that the index of the second image isn't the same as the first image + while random_index_2 == random_index_1: + random_index_2 = random.randint(0, self.grouped_examples[selected_class].shape[0]-1) + + # pick the index to get the second image + index_2 = self.grouped_examples[selected_class][random_index_2] + + # get the second image + image_2 = self.data[index_2].clone().float() + + # set the label for this example to be positive (1) + target = torch.tensor(1, dtype=torch.float) + + # different class + else: + # pick a random class + other_selected_class = random.randint(0, 9) + + # ensure that the class of the second image isn't the same as the first image + while other_selected_class == selected_class: + other_selected_class = random.randint(0, 9) + + + # pick a random index for the second image in the grouped indices based of the label + # of the class + random_index_2 = random.randint(0, self.grouped_examples[other_selected_class].shape[0]-1) + + # pick the index to get the second image + index_2 = self.grouped_examples[other_selected_class][random_index_2] + + # get the second image + image_2 = self.data[index_2].clone().float() + + # set the label for this example to be negative (0) + target = torch.tensor(0, dtype=torch.float) + + return image_1, image_2, target + + +def train(args, model, device, train_loader, optimizer, epoch): + model.train() + + # we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick. + criterion = nn.BCELoss() + + for batch_idx, (images_1, images_2, targets) in enumerate(train_loader): + images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device) + optimizer.zero_grad() + outputs = model(images_1, images_2).squeeze() + loss = criterion(outputs, targets) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(images_1), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.item())) + if args.dry_run: + break + + +def test(model, device, test_loader): + model.eval() + test_loss = 0 + correct = 0 + + # we aren't using `TripletLoss` as the MNIST dataset is simple, so `BCELoss` can do the trick. + criterion = nn.BCELoss() + + with torch.no_grad(): + for (images_1, images_2, targets) in test_loader: + images_1, images_2, targets = images_1.to(device), images_2.to(device), targets.to(device) + outputs = model(images_1, images_2).squeeze() + test_loss += criterion(outputs, targets).sum().item() # sum up batch loss + pred = torch.where(outputs > 0.5, 1, 0) # get the index of the max log-probability + correct += pred.eq(targets.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + # for the 1st epoch, the average loss is 0.0001 and the accuracy 97-98% + # using default settings. After completing the 10th epoch, the average + # loss is 0.0000 and the accuracy 99.5-100% using default settings. + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) + + +def main(): + # Training settings + parser = argparse.ArgumentParser(description='PyTorch Siamese network Example') + parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') + parser.add_argument('--epochs', type=int, default=14, metavar='N', + help='number of epochs to train (default: 14)') + parser.add_argument('--lr', type=float, default=1.0, metavar='LR', + help='learning rate (default: 1.0)') + parser.add_argument('--gamma', type=float, default=0.7, metavar='M', + help='Learning rate step gamma (default: 0.7)') + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--dry-run', action='store_true', default=False, + help='quickly check a single pass') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--save-model', action='store_true', default=False, + help='For Saving the current Model') + args = parser.parse_args() + + use_cuda = not args.no_cuda and torch.cuda.is_available() + + torch.manual_seed(args.seed) + + device = torch.device("cuda" if use_cuda else "cpu") + + train_kwargs = {'batch_size': args.batch_size} + test_kwargs = {'batch_size': args.test_batch_size} + if use_cuda: + cuda_kwargs = {'num_workers': 1, + 'pin_memory': True, + 'shuffle': True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + + train_dataset = APP_MATCHER('../data', train=True, download=True) + test_dataset = APP_MATCHER('../data', train=False) + train_loader = torch.utils.data.DataLoader(train_dataset,**train_kwargs) + test_loader = torch.utils.data.DataLoader(test_dataset, **test_kwargs) + + model = SiameseNetwork().to(device) + optimizer = optim.Adadelta(model.parameters(), lr=args.lr) + + scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma) + for epoch in range(1, args.epochs + 1): + train(args, model, device, train_loader, optimizer, epoch) + test(model, device, test_loader) + scheduler.step() + + if args.save_model: + torch.save(model.state_dict(), "siamese_network.pt") + + +if __name__ == '__main__': + main() diff --git a/siamese_network/requirements.txt b/siamese_network/requirements.txt new file mode 100644 index 0000000000..ac988bdf84 --- /dev/null +++ b/siamese_network/requirements.txt @@ -0,0 +1,2 @@ +torch +torchvision