A Scalable Probabilistic Programming Language With JAX
Probabilistic programming languages (PPLs) have gained widespread popularity for Bayesian modelling, with established frameworks like Stan and PyMC becoming standard tools across domains. However, while recent research in variational inference has produced scalable methods capable of capturing complex non-Gaussian posteriors, most notably normalizing flows, these advances have yet to be integrated into the core workflows of mainstream PPLs. We introduce Bayinx, a PPL embedded in Python leveraging JAX for scalable modelling on both CPUs and GPUs. Bayinx focuses on normalizing flows to provide a high-fidelity variational approximation that serves as a performant alternative to Markov Chain Monte Carlo. Furthermore, Bayinx enables users to treat arbitrary Python objects as parameters, allowing seamless integration of custom classes using other JAX libraries into probabilistic models. We demonstrate these capabilities through various modelling examples.
Langage de programmation probabiliste évolutif avec JAX
Les langages de programmation probabiliste (PPL) ont acquis une grande popularité pour la modélisation bayésienne, des cadres bien établis tels que Stan et PyMC devenant des outils standard dans divers domaines. Cependant, alors que les recherches récentes en inférence variationnelle ont produit des méthodes évolutives capables de capturer des postérieurs non gaussiens complexes, notamment les flux de normalisation, ces avancées n'ont pas encore été intégrées aux flux de travail de base des PPL courants. Nous vous présentons Bayinx, un PPL intégré à Python qui exploite JAX pour une modélisation évolutive à la fois sur les processeurs et les processeurs graphiques. Bayinx normalise les flux pour fournir une approximation variationnelle haute fidélité qui constitue une alternative puissant à la méthode de Monte Carlo par chaînes de Markov. De plus, Bayinx permet aux utilisateurs de traiter des objets Python arbitraires comme des paramètres, offrant ainsi une intégration transparente de classes personnalisées utilisant d'autres bibliothèques JAX dans des modèles probabilistes. Nous démontrons ces capacités à travers divers exemples de modélisation.
Date and Time
-
Language of Oral Presentation
English
Language of Visual Aids
English