Architecture: Each client has its own network, which generates features from the data local to that client.
These client models are independent of each other, except that their last layers must produce feature maps of the same shape.
The network on the server side takes the features generated by a client as input and computes the segmentation mask using a much heavier network, such as a Segformer.
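As a minimal sketch of this split, the following uses simple linear layers as hypothetical stand-ins for the client extractors and the server network (the names `ClientModel`, `ServerModel`, and `FEAT_DIM` are illustrative, not from the original design). The point it shows is the one constraint above: clients may differ internally, but all must emit features of the same shape so the server can consume any client's output.

```python
import numpy as np

FEAT_DIM = 16  # shared feature-map dimensionality (assumed value)

class ClientModel:
    """Per-client feature extractor (a single linear layer here)."""
    def __init__(self, in_dim, rng):
        self.W = rng.standard_normal((in_dim, FEAT_DIM)) * 0.1

    def forward(self, x):
        # Features always have shape (batch, FEAT_DIM), regardless of in_dim.
        return x @ self.W

class ServerModel:
    """Heavier server head mapping shared features to class logits
    (a real system would use a segmentation network here)."""
    def __init__(self, n_classes, rng):
        self.W = rng.standard_normal((FEAT_DIM, n_classes)) * 0.1

    def forward(self, feats):
        return feats @ self.W

rng = np.random.default_rng(0)
# Clients may have different input dimensionalities and internals...
clients = [ClientModel(in_dim=d, rng=rng) for d in (8, 12)]
# ...but one server consumes features from any of them.
server = ServerModel(n_classes=3, rng=rng)

x = rng.standard_normal((4, 8))
feats = clients[0].forward(x)
logits = server.forward(feats)
assert feats.shape == (4, FEAT_DIM) and logits.shape == (4, 3)
```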
Training: Training proceeds in a round-robin fashion. In each round, one client is chosen; it computes a forward pass on its data and sends the resulting features to the server.
The server then computes the predicted mask and a loss between the prediction and the label. Next, the client weights are updated using zeroth-order optimization,
after which the server weights are updated using first-order optimization. This continues for many rounds across all clients, so each client indirectly benefits from the server having been trained on every client's data.
However, this can lead to the common problem of catastrophic forgetting, where training with one client makes the server forget what it learned from a previous client. Hence, after the server is trained in each round, it updates a
hashmap whose key is the client index and whose value is a copy of the server weights after training with that client.
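One round of this loop can be sketched as follows. This is a simplified, hypothetical instance: the models are single linear layers, the loss is MSE, and the zeroth-order client update uses a two-point SPSA-style estimate (perturb the weights by ±cΔ, difference the two losses); the original design does not specify which zeroth-order method is used. The server update uses the exact first-order MSE gradient, and the per-client snapshot goes into a plain dict.

```python
import numpy as np

rng = np.random.default_rng(1)
FEAT_DIM, N_CLASSES = 16, 3

def mse(pred, y):
    return float(np.mean((pred - y) ** 2))

# Toy per-client data and weights (hypothetical shapes).
client_W = [rng.standard_normal((8, FEAT_DIM)) * 0.1 for _ in range(2)]
server_W = rng.standard_normal((FEAT_DIM, N_CLASSES)) * 0.1
data = [(rng.standard_normal((32, 8)), rng.standard_normal((32, N_CLASSES)))
        for _ in range(2)]

snapshots = {}                 # hashmap: client index -> server weights
c, lr_client, lr_server = 1e-3, 1e-2, 1e-1   # assumed hyperparameters

for rnd in range(20):
    i = rnd % len(client_W)    # round-robin client selection
    x, y = data[i]

    # Zeroth-order (SPSA-style) update of the client weights: the client
    # never sees server gradients, only two scalar losses.
    delta = rng.choice([-1.0, 1.0], size=client_W[i].shape)
    loss_plus  = mse((x @ (client_W[i] + c * delta)) @ server_W, y)
    loss_minus = mse((x @ (client_W[i] - c * delta)) @ server_W, y)
    g_client = (loss_plus - loss_minus) / (2 * c) * delta
    client_W[i] -= lr_client * g_client

    # First-order update of the server weights (exact MSE gradient).
    feats = x @ client_W[i]
    pred = feats @ server_W
    g_server = 2.0 / len(x) * feats.T @ (pred - y)
    server_W -= lr_server * g_server

    # Snapshot the server weights, keyed by the client just trained with.
    snapshots[i] = server_W.copy()

assert set(snapshots) == {0, 1}
```

Note that `snapshots[i]` is overwritten each time client `i` comes around, so the hashmap always holds the server state as of that client's most recent round, while `server_W` itself keeps accumulating updates from all clients.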
Inference: During inference, the client passes its feature maps to the server along with its index. The server uses the client index to query
the hashmap and retrieve the server weights that were last trained with that client, then computes the predicted segmentation mask. Note that since the hashmap is only updated during training and only read during inference,
the live server weights continue to be trained on every client while each client is served with the snapshot tailored to it. After multiple rounds, the server learns a rich set of representations.
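The inference path is then just a hashmap lookup before the server's forward pass. The sketch below is self-contained and hypothetical (linear stand-ins again, with randomly initialized state playing the role of weights left over from training); only the lookup-by-client-index logic is the point.

```python
import numpy as np

rng = np.random.default_rng(2)
FEAT_DIM, N_CLASSES = 16, 3

# Hypothetical state left over from training: one client's extractor and
# the hashmap of per-client server weight snapshots.
client_W = {0: rng.standard_normal((8, FEAT_DIM)) * 0.1}
snapshots = {0: rng.standard_normal((FEAT_DIM, N_CLASSES)) * 0.1}

def infer(client_idx, x):
    """Client sends features plus its index; the server looks up the
    snapshot last trained with that client and predicts the mask."""
    feats = x @ client_W[client_idx]    # computed on the client side
    server_W = snapshots[client_idx]    # hashmap lookup on the server side
    return feats @ server_W

logits = infer(0, rng.standard_normal((4, 8)))
assert logits.shape == (4, N_CLASSES)
```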