
Differentially private machine learning at scale with JAX-Privacy

Google DeepMind and Google Research have announced the release of JAX-Privacy 1.0, a high-performance library designed to scale differentially private (DP) machine learning. By leveraging JAX’s native parallelization and functional programming model, the toolkit enables researchers to train large-scale foundation models while maintaining rigorous privacy guarantees. This version introduces modular components for advanced algorithms and empirical auditing, making private training both computationally efficient and verifiable across distributed environments.

Scaling Differential Privacy with JAX

  • The library is built directly on the JAX ecosystem, integrating seamlessly with Flax for neural network architectures and Optax for optimization.
  • It utilizes JAX’s vmap for automatic vectorization and shard_map for single-program multiple-data (SPMD) parallelism, allowing DP primitives to scale across multiple accelerators (a minimal vmap sketch follows this list).
  • By using just-in-time (JIT) compilation, the library mitigates the traditional performance overhead associated with per-example gradient clipping and noise addition.
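
To make that concrete, here is a minimal sketch of the vmap-plus-JIT pattern the library builds on: per-example gradients computed in one vectorized call, clipped to a fixed L2 norm, and averaged. The names here (loss_fn, clip_norm, clipped_grad_mean) are illustrative, not JAX-Privacy's actual API.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy linear-regression loss for a single example.
    pred = jnp.dot(x, params)
    return (pred - y) ** 2

def clip_gradient(grad, clip_norm):
    # Rescale one per-example gradient so its L2 norm is at most clip_norm.
    norm = jnp.linalg.norm(grad)
    return grad * jnp.minimum(1.0, clip_norm / (norm + 1e-12))

@jax.jit
def clipped_grad_mean(params, xs, ys, clip_norm=1.0):
    # vmap yields one gradient per example: params broadcast, data batched.
    per_example_grads = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(params, xs, ys)
    clipped = jax.vmap(clip_gradient, in_axes=(0, None))(per_example_grads, clip_norm)
    return jnp.mean(clipped, axis=0)

params = jnp.zeros(4)
xs = jax.random.normal(jax.random.PRNGKey(0), (32, 4))
ys = jnp.ones(32)
print(clipped_grad_mean(params, xs, ys))
```

Because the whole pipeline is a single jitted function, XLA can fuse the per-example gradient, clipping, and reduction steps, which is where the overhead savings come from.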

Core Components and Advanced Algorithms

  • The toolkit provides fundamental building blocks for implementing standard DP algorithms like DP-SGD and DP-FTRL, including specialized modules for data batch construction.
  • It supports state-of-the-art methods such as DP matrix factorization, which improves the privacy-utility trade-off by injecting correlated noise across training iterations (a toy illustration follows this list).
  • Features like micro-batching and padding are included to handle the massive, variable-sized batches often required for a favorable balance between privacy and model utility.
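
As a rough illustration of the matrix-factorization idea, the sketch below follows the standard formulation from the DP-MF literature rather than the library's internals: factor the prefix-sum workload A = BC, add iid Gaussian noise z to the C-space statistics, and the noise the training run actually sees, B @ z, is correlated across iterations. The toy dimensions and the particular Cholesky-based factorization are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

T = 8  # toy number of training iterations
A = jnp.tril(jnp.ones((T, T)))      # prefix-sum workload matrix (cumulative gradient sums)

# One valid factorization A = B @ C: take C as the Cholesky factor of A^T A.
L = jnp.linalg.cholesky(A.T @ A)    # lower triangular, L @ L.T == A.T @ A
C = L.T                             # upper triangular, C.T @ C == A.T @ A
B = A @ jnp.linalg.inv(C)

# The mechanism releases B @ (C @ g + z) = A @ g + B @ z for gradients g,
# so B @ z is the effective noise. Its privacy cost is governed by the
# largest column norm of C (the sensitivity), not by A directly.
sigma = 1.0
key = jax.random.PRNGKey(0)
z = sigma * jax.random.normal(key, (T,))  # iid Gaussian noise, one draw per step

correlated_noise = B @ z
print(correlated_noise)
```

Choosing B and C to minimize the error of B @ z at fixed sensitivity is what gives DP-MF its advantage over independent per-step noise.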

Verification and Privacy Auditing

  • JAX-Privacy incorporates rigorous privacy accounting based on Rényi Differential Privacy (RDP) to track privacy budgets precisely (a back-of-the-envelope accounting example follows this list).
  • The library includes tools for empirical auditing, allowing developers to validate their privacy guarantees through techniques such as membership inference and data poisoning attacks.
  • The design ensures correctness in distributed settings, specifically focusing on consistent noise generation and gradient synchronization across clusters.
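
For a sense of what RDP accounting computes, here is a back-of-the-envelope example for the plain Gaussian mechanism: RDP composes additively across training steps and is then converted to an (ε, δ) guarantee by optimizing over the Rényi order α. This deliberately omits privacy amplification by subsampling, which a production accountant would include, so the numbers are illustrative only.

```python
import numpy as np

def gaussian_rdp(alpha, sigma, sensitivity=1.0):
    # RDP of the Gaussian mechanism at order alpha (Mironov, 2017).
    return alpha * sensitivity**2 / (2 * sigma**2)

def rdp_to_eps(rdp, alpha, delta):
    # Standard conversion from (alpha, rdp)-RDP to (eps, delta)-DP.
    return rdp + np.log(1.0 / delta) / (alpha - 1)

sigma, steps, delta = 20.0, 100, 1e-5
alphas = np.arange(2, 128)
total_rdp = steps * gaussian_rdp(alphas, sigma)  # RDP composes additively
eps = np.min(rdp_to_eps(total_rdp, alphas, delta))
print(f"(eps={eps:.2f}, delta={delta}) after {steps} steps")
```

Note how large σ must be here to reach a single-digit ε without amplification; subsampled accounting is what lets real DP training use moderate noise multipliers over many more steps.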

JAX-Privacy 1.0 is a robust option for researchers and engineers who need to deploy production-grade private models. Its modular architecture and integration with high-performance computing primitives make it a strong choice for training foundation models on sensitive datasets without compromising scalability or security.