Federated Black-Box Adaptation for Semantic Segmentation
Jay N. Paranjape
Shameema Sikder
S. Swaroop Vedula
Vishal M. Patel
Johns Hopkins University
[Paper]
[GitHub]

Abstract

Federated Learning (FL) is a form of distributed learning that allows multiple institutions or clients to collaboratively learn a global model to solve a task. This allows the model to utilize information from every institution while preserving data privacy. However, recent studies show that existing methods do not uphold the promise of data privacy: it is possible to recreate the training data from the different institutions by exploiting the gradients transferred between the clients and the global server during training, or by knowing the model architecture at the client end. In this paper, we propose a federated learning framework for semantic segmentation that requires neither knowledge of the model architecture nor the transfer of gradients between the client and the server, thus enabling better privacy preservation. We propose BlackFed - a black-box adaptation of neural networks that uses zero-order optimization (ZOO) to update the client model weights and first-order optimization (FOO) to update the server weights. We evaluate our approach on several computer vision and medical imaging datasets to demonstrate its effectiveness. To the best of our knowledge, this is one of the first works to employ federated learning for segmentation without any exchange of gradients or model information.


Method

Architecture: Each client has its own network, which generates features from the data at that client. These client models are independent of each other, except that their last layers must produce feature maps of the same shape. The network on the server side takes the features generated by a client as input and computes the segmentation mask using a much heavier network, such as a SegFormer.
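The architecture constraint above can be sketched as follows. This is a minimal numpy illustration, not the paper's actual networks: the tiny two-layer MLPs stand in for the clients' small CNNs, and the fixed feature shape plays the role of the shared interface to the server.

```python
import numpy as np

rng = np.random.default_rng(0)
FEAT_SHAPE = (8, 16, 16)              # (channels, H, W) every client must emit
FEAT_DIM = int(np.prod(FEAT_SHAPE))

def make_client(in_dim, hidden):
    """Hypothetical client model: a tiny 2-layer MLP standing in for a small CNN."""
    w1 = rng.standard_normal((hidden, in_dim)) * 0.01
    w2 = rng.standard_normal((FEAT_DIM, hidden)) * 0.01
    return lambda x: (w2 @ np.tanh(w1 @ x)).reshape(FEAT_SHAPE)

# Clients differ internally (input size, hidden width) but agree on output shape,
# so the server can consume features from any of them.
clients = [make_client(1024, 32), make_client(2048, 64)]
feats = [c(rng.standard_normal(d)) for c, d in zip(clients, [1024, 2048])]
assert all(f.shape == FEAT_SHAPE for f in feats)
```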

Training: Training occurs in a round-robin fashion. In each round, one client is chosen. It computes a forward pass on its data and passes the resulting features to the server. The server computes the predicted mask and a loss between the label and the prediction. The client weights are then updated using zero-order optimization (ZOO), after which the server weights are updated using first-order optimization (FOO). This continues for several rounds with each client, so that every client indirectly benefits from the server being trained on all clients. However, this can lead to the common problem of catastrophic forgetting, where training with one client makes the server forget information from the previous client. Hence, after training with a client in each round, the server updates a hashmap whose key is the client index and whose value is the server weights after training with that client.
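The round-robin loop above can be sketched with toy linear models. This is a hypothetical illustration under strong simplifications, not the paper's implementation: SPSA is used as a standard zero-order estimator for the client update, the server's first-order step is the analytic MSE gradient, and both "networks" are single linear maps.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 16                                 # feature dimension exchanged with the server

# Hypothetical toy models: each client and the server are single linear maps.
clients = [rng.standard_normal((D, D)) * 0.1 for _ in range(2)]
w_server = rng.standard_normal((1, D)) * 0.1
server_store = {}                      # hashmap: client index -> server weights

def loss(w_c, w_s, x, y):
    f = w_c @ x                        # client forward pass (features sent up)
    pred = w_s @ f                     # server forward pass
    return float(((pred - y) ** 2).mean()), f, pred

data = [(rng.standard_normal(D), np.array([1.0])),
        (rng.standard_normal(D), np.array([-1.0]))]
init_losses = [loss(clients[i], w_server, *data[i])[0] for i in range(2)]

for rnd in range(200):
    i = rnd % len(clients)             # round-robin client selection
    x, y = data[i]

    # Client update via SPSA (zero-order): two loss queries, no gradients shared.
    delta = rng.choice([-1.0, 1.0], size=clients[i].shape)
    c, lr_c = 1e-3, 1e-3
    lp, _, _ = loss(clients[i] + c * delta, w_server, x, y)
    lm, _, _ = loss(clients[i] - c * delta, w_server, x, y)
    clients[i] -= lr_c * (lp - lm) / (2 * c) * delta

    # Server update via the first-order gradient of the MSE loss.
    _, f, pred = loss(clients[i], w_server, x, y)
    grad_s = 2 * (pred - y)[:, None] * f[None, :] / pred.size
    w_server -= 1e-2 * grad_s

    # Snapshot server weights keyed by client index (mitigates forgetting).
    server_store[i] = w_server.copy()

final_losses = [loss(clients[i], server_store[i], *data[i])[0] for i in range(2)]
```

The key property this preserves is that the client only ever sends features and receives loss values, so no gradients or architecture details cross the boundary.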

Inference: During inference, the client passes its feature maps to the server. The server uses the client's index to query the hashmap and retrieves the server weights that were last trained with that client. It then computes the predicted segmentation mask. Note that since the hashmap is updated during training and only queried during inference, the server weights are still trained on every client; after multiple rounds, the server learns a rich set of representations.
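The hashmap lookup at inference time reduces to the following sketch (hypothetical toy state: a linear map stands in for the server's segmentation head, and the stored weights are placeholders):

```python
import numpy as np

D = 16
# Hypothetical state after training: per-client snapshots of the server weights.
server_store = {0: np.ones((1, D)), 1: -np.ones((1, D))}
client_feat = np.full(D, 0.5)          # features sent by client 1 at test time

def server_infer(client_idx, feats):
    """Server retrieves the weights last trained with this client, then predicts."""
    w_s = server_store[client_idx]     # hashmap lookup by client index
    return w_s @ feats                 # stand-in for the segmentation head

pred = server_infer(1, client_feat)
```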


Quantitative Results

Comparison of BlackFed against individual training. The third and fourth columns denote testing with the local test data, while the fifth and sixth columns denote OOD testing. Our method improves OOD performance of clients without harming their local performance.


Paper and Supplementary Material


Federated Black-Box Adaptation for Semantic Segmentation

(hosted on ArXiv)


[Bibtex]


Acknowledgements

This template was originally made by Phillip Isola and Richard Zhang for a colorful ECCV project; the code can be found here.