
Research Group Stefanie Jegelka


Stefanie Jegelka

Prof. Dr.

Principal Investigator

Foundations of Deep Neural Networks

Stefanie Jegelka is a Humboldt Professor at TU Munich.

Her research is in algorithmic machine learning, and spans modeling, optimization algorithms, theory and applications. In particular, she has been working on exploiting mathematical structure for discrete and combinatorial machine learning problems, for robustness and for scaling machine learning algorithms.

Team members @MCML


Andreas Bergmeister

Foundations of Deep Neural Networks


Valerie Engelmayer

Foundations of Deep Neural Networks


Eduardo Santos Escriche

Foundations of Deep Neural Networks

Publications @MCML

[9]
G. Ma, Y. Wang, D. Lim, S. Jegelka and Y. Wang.
A Canonicalization Perspective on Invariant and Equivariant Learning.
NeurIPS 2024 - 38th Conference on Neural Information Processing Systems. Vancouver, Canada, Dec 10-15, 2024. To be published. Preprint available. arXiv. GitHub.
Abstract

In many applications, we desire neural networks to exhibit invariance or equivariance to certain groups due to symmetries inherent in the data. Recently, frame-averaging methods have emerged as a unified framework for attaining symmetries efficiently by averaging over input-dependent subsets of the group, i.e., frames. What we currently lack is a principled understanding of the design of frames. In this work, we introduce a canonicalization perspective that provides an essential and complete view of the design of frames. Canonicalization is a classic approach for attaining invariance by mapping inputs to their canonical forms. We show that there exists an inherent connection between frames and canonical forms. Leveraging this connection, we can efficiently compare the complexity of frames as well as determine the optimality of certain frames. Guided by this principle, we design novel frames for eigenvectors that are strictly superior to existing methods (some are even optimal) both theoretically and empirically. The reduction to the canonicalization perspective further uncovers equivalences between previous methods. These observations suggest that canonicalization provides a fundamental understanding of existing frame-averaging methods and unifies existing equivariant and invariant learning methods.
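
As a concrete illustration of the two constructions the abstract relates, the following minimal NumPy sketch handles the simplest symmetry relevant to eigenvectors, the sign ambiguity {v, -v}. The canonical form used here (make the largest-magnitude entry positive) and the feature map f are generic textbook choices for illustration, not the frames proposed in the paper.

import numpy as np

def canonicalize_sign(v):
    # Canonical form for the sign ambiguity of an eigenvector:
    # flip the sign so that the entry with the largest magnitude is positive.
    i = np.argmax(np.abs(v))
    return v if v[i] >= 0 else -v

def frame_average(f, v):
    # Frame averaging over the sign group {+1, -1}: average f over the orbit {v, -v}.
    return 0.5 * (f(v) + f(-v))

rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5))
A = A + A.T                                  # symmetric matrix, so eigenvectors are real
v = np.linalg.eigh(A)[1][:, 0]

f = lambda x: np.abs(x) + x ** 2             # an arbitrary feature map
print(np.allclose(canonicalize_sign(v), canonicalize_sign(-v)))   # True: canonical form is sign-invariant
print(np.allclose(frame_average(f, v), frame_average(f, -v)))     # True: frame average is sign-invariant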

MCML Authors: Stefanie Jegelka (Prof. Dr., Foundations of Deep Neural Networks)


[8]
Y. Wang, K. Hu, S. Gupta, Z. Ye, Y. Wang and S. Jegelka.
Understanding the Role of Equivariance in Self-supervised Learning.
NeurIPS 2024 - 38th Conference on Neural Information Processing Systems. Vancouver, Canada, Dec 10-15, 2024. To be published. Preprint available. arXiv. GitHub.
Abstract

Contrastive learning has been a leading paradigm for self-supervised learning, but it is widely observed that it comes at the price of sacrificing useful features (e.g., colors) by being invariant to data augmentations. Given this limitation, there has been a surge of interest in equivariant self-supervised learning (E-SSL) that learns features to be augmentation-aware. However, even for the simplest rotation prediction method, there is a lack of rigorous understanding of why, when, and how E-SSL learns useful features for downstream tasks. To bridge this gap between practice and theory, we establish an information-theoretic perspective to understand the generalization ability of E-SSL. In particular, we identify a critical explaining-away effect in E-SSL that creates a synergy between the equivariant and classification tasks. This synergy effect encourages models to extract class-relevant features to improve their equivariant prediction, which, in turn, benefits downstream tasks requiring semantic features. Based on this perspective, we theoretically analyze the influence of data transformations and reveal several principles for practical designs of E-SSL. Our theory not only aligns well with existing E-SSL methods but also sheds light on new directions by exploring the benefits of model equivariance. We believe that a theoretically grounded understanding of the role of equivariance would inspire more principled and advanced designs in this field.
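
The rotation-prediction pretext task mentioned in the abstract is easy to state in code. The NumPy sketch below only builds a pretext batch (inputs rotated by a random multiple of 90 degrees plus the corresponding rotation labels); an encoder trained to predict these labels is the E-SSL setting the paper analyzes. Shapes and names are illustrative.

import numpy as np

def rotation_pretext_batch(images, rng):
    # Each image gets a random rotation in {0, 90, 180, 270} degrees;
    # the pretext label is the index of that rotation.
    ks = rng.integers(0, 4, size=len(images))
    rotated = np.stack([np.rot90(img, k) for img, k in zip(images, ks)])
    return rotated, ks

rng = np.random.default_rng(0)
images = rng.standard_normal((8, 32, 32))    # stand-in for a batch of images
x, y = rotation_pretext_batch(images, rng)
print(x.shape, y)                            # (8, 32, 32) and labels in {0, 1, 2, 3}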

MCML Authors: Stefanie Jegelka (Prof. Dr., Foundations of Deep Neural Networks)


[7]
Y. Wang, Y. Wu, Z. Wei, S. Jegelka and Y. Wang.
A Theoretical Understanding of Self-Correction through In-context Alignment.
NeurIPS 2024 - 38th Conference on Neural Information Processing Systems. Vancouver, Canada, Dec 10-15, 2024. To be published. Preprint available. arXiv.
Abstract

Going beyond mimicking limited human experiences, recent studies show initial evidence that, like humans, large language models (LLMs) are capable of improving their abilities purely by self-correction, i.e., correcting previous responses through self-examination, in certain circumstances. Nevertheless, little is known about how such capabilities arise. In this work, based on a simplified setup akin to an alignment task, we theoretically analyze self-correction from an in-context learning perspective, showing that when LLMs give relatively accurate self-examinations as rewards, they are capable of refining responses in an in-context way. Notably, going beyond previous theories on over-simplified linear transformers, our theoretical construction underpins the roles of several key designs of realistic transformers for self-correction: softmax attention, multi-head attention, and the MLP block. We validate these findings extensively on synthetic datasets. Inspired by these findings, we also illustrate novel applications of self-correction, such as defending against LLM jailbreaks, where a simple self-correction step does make a large difference. We believe that these findings will inspire further research on understanding, exploiting, and enhancing self-correction for building better foundation models.
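
For intuition, the self-correction loop analyzed in the abstract can be sketched in a few lines. The snippet assumes a hypothetical generate(prompt) callable wrapping an LLM; the helper name and prompts are illustrative placeholders, not an interface from the paper.

def self_correct(question, generate, rounds=2):
    # Draft an answer, then repeatedly critique and revise it with the same model.
    answer = generate(question)
    for _ in range(rounds):
        critique = generate(f"Question: {question}\nAnswer: {answer}\n"
                            "Point out any mistakes in this answer.")
        answer = generate(f"Question: {question}\nAnswer: {answer}\n"
                          f"Critique: {critique}\nGive an improved answer.")
    return answer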

MCML Authors: Stefanie Jegelka (Prof. Dr., Foundations of Deep Neural Networks)


[6]
M. Yau, N. Karalias, E. Lu, J. Xu and S. Jegelka.
Are Graph Neural Networks Optimal Approximation Algorithms?
NeurIPS 2024 - 38th Conference on Neural Information Processing Systems. Vancouver, Canada, Dec 10-15, 2024. To be published. Preprint available. arXiv.
Abstract

In this work we design graph neural network architectures that capture optimal approximation algorithms for a large class of combinatorial optimization problems, using powerful algorithmic tools from semidefinite programming (SDP). Concretely, we prove that polynomial-sized message-passing algorithms can represent the most powerful polynomial time algorithms for Max Constraint Satisfaction Problems assuming the Unique Games Conjecture. We leverage this result to construct efficient graph neural network architectures, OptGNN, that obtain high-quality approximate solutions on landmark combinatorial optimization problems such as Max-Cut, Min-Vertex-Cover, and Max-3-SAT. Our approach achieves strong empirical results across a wide range of real-world and synthetic datasets against solvers and neural baselines. Finally, we take advantage of OptGNN’s ability to capture convex relaxations to design an algorithm for producing bounds on the optimal solution from the learned embeddings of OptGNN.
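
The final sentence refers to producing solutions and bounds from learned embeddings by rounding a convex relaxation. A generic version of that rounding step for Max-Cut, random-hyperplane rounding in the spirit of Goemans-Williamson, is sketched below in NumPy; it illustrates the SDP-rounding idea only and is not the OptGNN training procedure.

import numpy as np

def random_hyperplane_cut(embeddings, edges, rng, trials=32):
    # Project node embeddings onto random hyperplanes and keep the best cut found.
    best_cut, best_sides = -1, None
    for _ in range(trials):
        r = rng.standard_normal(embeddings.shape[1])
        sides = embeddings @ r > 0               # which side of the hyperplane each node falls on
        cut = sum(sides[u] != sides[v] for u, v in edges)
        if cut > best_cut:
            best_cut, best_sides = cut, sides
    return best_cut, best_sides

rng = np.random.default_rng(0)
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0), (0, 3)]
embeddings = rng.standard_normal((6, 4))         # stand-in for learned node embeddings
print(random_hyperplane_cut(embeddings, edges, rng)[0])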

MCML Authors: Stefanie Jegelka (Prof. Dr., Foundations of Deep Neural Networks)


[5]
L. Fang, Y. Wang, Z. Liu, C. Zhang, S. Jegelka, J. Gao, B. Ding and Y. Wang.
What is Wrong with Perplexity for Long-context Language Modeling?
Preprint (Oct. 2024). arXiv. GitHub.
Abstract

Handling long-context inputs is crucial for large language models (LLMs) in tasks such as extended conversations, document summarization, and many-shot in-context learning. While recent approaches have extended the context windows of LLMs and employed perplexity (PPL) as a standard evaluation metric, PPL has proven unreliable for assessing long-context capabilities. The underlying cause of this limitation has remained unclear. In this work, we provide a comprehensive explanation for this issue. We find that PPL overlooks key tokens, which are essential for long-context understanding, by averaging across all tokens and thereby obscuring the true performance of models in long-context scenarios. To address this, we propose LongPPL, a novel metric that focuses on key tokens by employing a long-short context contrastive method to identify them. Our experiments demonstrate that LongPPL strongly correlates with performance on various long-context benchmarks (e.g., Pearson correlation of -0.96), significantly outperforming traditional PPL in predictive accuracy. Additionally, we introduce LongCE (Long-context Cross-Entropy) loss, a re-weighting strategy for fine-tuning that prioritizes key tokens, leading to consistent improvements across diverse benchmarks. In summary, these contributions offer deeper insights into the limitations of PPL and present effective solutions for accurately evaluating and enhancing the long-context capabilities of LLMs.
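
To make the contrast concrete: standard perplexity averages negative log-likelihoods over all tokens, while a LongPPL-style metric averages only over key tokens. The toy NumPy sketch below takes the key-token mask as given; identifying key tokens via the long-short context contrastive method is the paper's contribution and is not reproduced here.

import numpy as np

def perplexity(logprobs):
    # Standard PPL: exponentiated mean negative log-likelihood over all tokens.
    return float(np.exp(-np.mean(logprobs)))

def key_token_perplexity(logprobs, key_mask):
    # LongPPL-flavoured variant: average only over tokens flagged as key.
    key = np.asarray(logprobs)[np.asarray(key_mask)]
    return float(np.exp(-np.mean(key)))

logprobs = np.array([-0.1, -0.2, -5.0, -0.1, -4.5])    # toy per-token log-probabilities
key_mask = np.array([False, False, True, False, True])
print(perplexity(logprobs), key_token_perplexity(logprobs, key_mask))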

MCML Authors: Stefanie Jegelka (Prof. Dr., Foundations of Deep Neural Networks)


[4]
K. Gatmiry, N. Saunshi, S. J. Reddi, S. Jegelka and S. Kumar.
On the Role of Depth and Looping for In-Context Learning with Task Diversity.
Preprint (Oct. 2024). arXiv.
Abstract

The intriguing in-context learning (ICL) abilities of deep Transformer models have lately garnered significant attention. By studying in-context linear regression on unimodal Gaussian data, recent empirical and theoretical works have argued that ICL emerges from Transformers’ abilities to simulate learning algorithms like gradient descent. However, these works fail to capture the remarkable ability of Transformers to learn multiple tasks in context. To this end, we study in-context learning for linear regression with diverse tasks, characterized by data covariance matrices with condition numbers ranging over [1, κ], and highlight the importance of depth in this setting. More specifically, (a) we show theoretical lower bounds of log(κ) (or √κ) linear attention layers in the unrestricted (or restricted) attention setting, and (b) we show that multilayer Transformers can indeed solve such tasks with a number of layers that matches the lower bounds. However, we show that this expressivity of multilayer Transformers comes at the price of robustness. In particular, multilayer Transformers are not robust even to distributional shifts as small as O(e^{-L}) in Wasserstein distance, where L is the depth of the network. We then demonstrate that Looped Transformers, a special class of multilayer Transformers with weight sharing, not only exhibit similar expressive power but are also provably robust under mild assumptions. Besides out-of-distribution generalization, we also show that Looped Transformers are the only models that exhibit a monotonic behavior of loss with respect to depth.
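
The weight-sharing idea behind Looped Transformers is simple to state in code: apply one block L times rather than stacking L independently parameterized blocks. The PyTorch sketch below only illustrates this parameter sharing (note the parameter counts); it is not the construction analyzed in the paper.

import torch
import torch.nn as nn

class LoopedBlock(nn.Module):
    def __init__(self, dim, loops):
        super().__init__()
        # One transformer block whose weights are reused at every depth.
        self.block = nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True)
        self.loops = loops

    def forward(self, x):
        for _ in range(self.loops):
            x = self.block(x)        # identical weights applied at each of the L steps
        return x

x = torch.randn(2, 16, 64)
looped = LoopedBlock(dim=64, loops=6)
stacked = nn.Sequential(*[nn.TransformerEncoderLayer(64, 4, batch_first=True) for _ in range(6)])
print(looped(x).shape)                               # torch.Size([2, 16, 64])
print(sum(p.numel() for p in looped.parameters()),   # parameters of a single shared block
      sum(p.numel() for p in stacked.parameters()))  # roughly six times as many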

MCML Authors: Stefanie Jegelka (Prof. Dr., Foundations of Deep Neural Networks)


[3]
T. Putterman, D. Lim, Y. Gelberg, S. Jegelka and H. Maron.
Learning on LoRAs: GL-Equivariant Processing of Low-Rank Weight Spaces for Large Finetuned Models.
Preprint (Oct. 2024). arXiv.
Abstract

Low-rank adaptations (LoRAs) have revolutionized the finetuning of large foundation models, enabling efficient adaptation even with limited computational resources. The resulting proliferation of LoRAs presents exciting opportunities for applying machine learning techniques that take these low-rank weights themselves as inputs. In this paper, we investigate the potential of Learning on LoRAs (LoL), a paradigm where LoRA weights serve as input to machine learning models. For instance, an LoL model that takes in LoRA weights as inputs could predict the performance of the finetuned model on downstream tasks, detect potentially harmful finetunes, or even generate novel model edits without traditional training methods. We first identify the inherent parameter symmetries of low-rank decompositions of weights, which differ significantly from the parameter symmetries of standard neural networks. To efficiently process LoRA weights, we develop several symmetry-aware invariant or equivariant LoL models, using tools such as canonicalization, invariant featurization, and equivariant layers. We finetune thousands of text-to-image diffusion models and language models to collect datasets of LoRAs. In numerical experiments on these datasets, we show that our LoL architectures are capable of processing low-rank weight decompositions to predict CLIP score, finetuning data attributes, finetuning data membership, and accuracy on downstream tasks.
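
The parameter symmetry mentioned in the abstract is the GL(r) reparameterization of the low-rank factors: rescaling one factor by an invertible matrix and the other by its inverse leaves the LoRA update unchanged. The NumPy sketch below verifies this and shows one simple invariant featurization (an illustrative choice, not the paper's architectures).

import numpy as np

rng = np.random.default_rng(0)
d, k, r = 8, 6, 2
A = rng.standard_normal((d, r))              # low-rank factors; the weight update is A @ B
B = rng.standard_normal((r, k))

M = rng.standard_normal((r, r))              # (almost surely) invertible r x r matrix
A2, B2 = A @ M, np.linalg.inv(M) @ B         # GL(r) action on the factor pair

# The update itself is unchanged, so (A, B) and (A2, B2) should be treated as the same input.
print(np.allclose(A @ B, A2 @ B2))           # True

# Any feature computed from the product, e.g. its singular values, is therefore invariant.
print(np.allclose(np.linalg.svd(A @ B, compute_uv=False),
                  np.linalg.svd(A2 @ B2, compute_uv=False)))        # True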

MCML Authors: Stefanie Jegelka (Prof. Dr., Foundations of Deep Neural Networks)


[2]
M. Yau, E. Akyürek, J. Mao, J. B. Tenenbaum, S. Jegelka and J. Andreas.
Learning Linear Attention in Polynomial Time.
Preprint (Oct. 2024). arXiv.
Abstract

Previous research has explored the computational expressivity of Transformer models in simulating Boolean circuits or Turing machines. However, the learnability of these simulators from observational data has remained an open question. Our study addresses this gap by providing the first polynomial-time learnability results (specifically strong, agnostic PAC learning) for single-layer Transformers with linear attention. We show that linear attention may be viewed as a linear predictor in a suitably defined RKHS. As a consequence, the problem of learning any linear transformer may be converted into the problem of learning an ordinary linear predictor in an expanded feature space, and any such predictor may be converted back into a multiheaded linear transformer. Moving to generalization, we show how to efficiently identify training datasets for which every empirical risk minimizer is equivalent (up to trivial symmetries) to the linear Transformer that generated the data, thereby guaranteeing the learned model will correctly generalize across all inputs. Finally, we provide examples of computations expressible via linear attention and therefore polynomial-time learnable, including associative memories, finite automata, and a class of Universal Turing Machines (UTMs) with polynomially bounded computation histories. We empirically validate our theoretical findings on three tasks: learning random linear attention networks, key–value associations, and learning to execute finite automata. Our findings bridge a critical gap between theoretical expressivity and learnability of Transformers, and show that flexible and general models of computation are efficiently learnable.
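
The linear-predictor claim can be checked numerically for one common parameterization of (softmax-free) linear attention: the output at a query position is a linear function of degree-3 features built from the query token and the second moments of the context. The NumPy sketch below is only a sanity check of that identity with illustrative shapes; the paper's feature map and learnability results go further.

import numpy as np

rng = np.random.default_rng(0)
d, p, m, T = 4, 3, 2, 5                      # input dim, key/query dim, value dim, sequence length
X = rng.standard_normal((T, d))
Wq, Wk, Wv = (rng.standard_normal((d, p)),
              rng.standard_normal((d, p)),
              rng.standard_normal((d, m)))

# Single-head linear attention (no softmax), with the last token as the query.
q = X[-1] @ Wq
scores = (X @ Wk) @ q                        # unnormalized linear-attention scores, shape (T,)
out = scores @ (X @ Wv)                      # attention output, shape (m,)

# The same output as a linear predictor over expanded features:
# Phi[a, e, f] = x_query[a] * sum_j X[j, e] X[j, f];  Theta[a, e, f, c] = (Wq Wk^T)[a, e] * Wv[f, c].
Phi = np.einsum('a,je,jf->aef', X[-1], X, X)
Theta = np.einsum('ae,fc->aefc', Wq @ Wk.T, Wv)
out_linear = np.einsum('aef,aefc->c', Phi, Theta)
print(np.allclose(out, out_linear))          # True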

MCML Authors: Stefanie Jegelka (Prof. Dr., Foundations of Deep Neural Networks)


[1]
Q. Zhang, Y. Wang, J. Cui, X. Pan, Q. Lei, S. Jegelka and Y. Wang.
Beyond Interpretability: The Gains of Feature Monosemanticity on Model Robustness.
Preprint (Oct. 2024). arXiv.
Abstract

Deep learning models often suffer from a lack of interpretability due to polysemanticity, where individual neurons are activated by multiple unrelated semantics, resulting in unclear attributions of model behavior. Recent advances in monosemanticity, where neurons correspond to consistent and distinct semantics, have significantly improved interpretability but are commonly believed to compromise accuracy. In this work, we challenge the prevailing belief of an accuracy-interpretability tradeoff, showing that monosemantic features not only enhance interpretability but also bring concrete gains in model performance. Across multiple robust learning scenarios, including input and label noise, few-shot learning, and out-of-domain generalization, our results show that models leveraging monosemantic features significantly outperform those relying on polysemantic features. Furthermore, we provide empirical and theoretical understanding of the robustness gains of feature monosemanticity. Our preliminary analysis suggests that monosemanticity, by promoting better separation of feature representations, leads to more robust decision boundaries. This diverse evidence highlights the generality of monosemanticity in improving model robustness. As a first step in this new direction, we embark on exploring the learning benefits of monosemanticity beyond interpretability, supporting the long-standing hypothesis of linking interpretability and robustness.

MCML Authors: Stefanie Jegelka (Prof. Dr., Foundations of Deep Neural Networks)