A research team from Skoltech and the AIRI Institute has proposed a new approach to solving a broad class of complex computational problems based on optimal transport, which are widely used in machine learning and mathematical modeling. The method speeds up model training by a factor of 3 to 10. The researchers will present their results in Vancouver at NeurIPS 2024, one of the world’s largest AI conferences.
Optimal transport methods are increasingly used to train generative models that synthesize artificial data such as images or text. Another important application is adapting models to data from new sources, which matters especially in medicine, where researchers often work with small and heterogeneous samples. However, existing neural-network methods for solving optimal transport problems suffer from several drawbacks, including unstable training and the need for complex intermediate transformations.
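To make the underlying problem concrete: given two distributions and a cost for moving mass between points, optimal transport looks for the cheapest plan that turns one distribution into the other. The sketch below is a minimal, hedged illustration using the classical entropy-regularized Sinkhorn algorithm on two tiny discrete distributions in plain Python. It is not ENOT and not a neural solver; the function name `sinkhorn` and the parameters `eps` and `n_iters` are illustrative choices of ours, not part of the researchers’ code.

```python
import math


def sinkhorn(a, b, cost, eps=0.25, n_iters=500):
    """Entropy-regularized optimal transport between two discrete distributions.

    Illustrative sketch of the classical Sinkhorn algorithm (not ENOT).
    a: source weights (non-negative, summing to 1)
    b: target weights (non-negative, summing to 1)
    cost: cost[i][j] is the price of moving one unit of mass from source i to target j
    eps: entropic regularization strength; smaller values approach exact OT
         but converge more slowly and can underflow math.exp
    Returns the transport plan P, where P[i][j] is the mass moved from i to j.
    """
    n, m = len(a), len(b)
    # Gibbs kernel K[i][j] = exp(-cost[i][j] / eps)
    K = [[math.exp(-cost[i][j] / eps) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iters):
        # Rescale rows so that each row of the plan sums to a[i]
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Rescale columns so that each column of the plan sums to b[j]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    # The plan is diag(u) @ K @ diag(v)
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


# Usage: two points on a line with a squared-distance cost
xs = [0.0, 1.0]   # source locations
ys = [0.0, 1.0]   # target locations
a = [0.5, 0.5]    # source mass
b = [0.25, 0.75]  # target mass: a quarter of the mass must shift to the right
C = [[(x - y) ** 2 for y in ys] for x in xs]
P = sinkhorn(a, b, C)
total_cost = sum(P[i][j] * C[i][j] for i in range(2) for j in range(2))
print(total_cost)  # roughly 0.25: moving 0.25 units a distance of 1
```

The exact optimal plan here keeps 0.25 units at point 0 and moves 0.25 units from 0 to 1, for a cost of 0.25; the entropic term adds a small blur, so the result is slightly above that. Classical solvers like this one work on fixed point sets, whereas neural OT methods such as the one described in this article learn maps that generalize to new samples, which is where training stability and speed become the main concerns.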
The key feature of the proposed method, called ENOT and implemented in the JAX framework, is the introduction of explicit regularization. This speeds up computation by a factor of 3 to 10 and improves the target quality metrics of the models. The researchers first ran experiments on two-dimensional data and then tested the method on image generation, style transfer, and 3D object reconstruction, which confirmed its versatility.
“Despite the theoretical nature of the paper, the method rests on an intuitive idea: to ‘pull’ the generated data toward the expected data. It seems to me that there are too many ‘black boxes’ in our field today. Of course, a tenfold speedup is a weighty argument, but I think the NeurIPS reviewers liked our method precisely for its intuitiveness,” said Dmitry Dylov, associate professor at the Skoltech AI Center and director of the ‘AGI Med’ Laboratory at the AIRI Institute.
“Our method is the fastest and most accurate to date. As for practical applications, optimal transport problems are multidisciplinary, so the method can be applied in a wide variety of fields. In particular, we used it for imitation learning: an expert demonstrates certain actions, the agent tries to reproduce that behavior, and the system evaluates how closely the agent’s actions match the expert’s. Think of a dance lesson, where the teacher shows a movement and the student tries to repeat it,” explained Nazar Buzun, head of the ‘Representation learning’ group at the AIRI Institute’s ‘AGI Med’ Laboratory.
The method has already drawn attention in the scientific community. The paper was accepted to the conference as a Spotlight, a special category for papers singled out by reviewers. In addition, the developers of ott-jax, one of the leading libraries for optimal transport, quickly added the method to their library.