Skip to content

Commit

Permalink
Support advanced generics redeclarations in RepositoryFactoryBeanSupp…
Browse files Browse the repository at this point in the history
…ort extensions.

Spring Data modules might override, and, by that, fix some of the generic type parameters exposed by RepositoryFactoryBeanSupport. We now more thoroughly walk through them to consider the ones expanded already and automatically expand the remaining ones with either the types found on the user repository interface or the unresolved type variable.

Ticket: GH-3074.
  • Loading branch information
odrotbohm committed Apr 9, 2024
1 parent dd081d4 commit 88011e6
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.springframework.core.metrics.StartupStep;
import org.springframework.data.repository.core.RepositoryMetadata;
import org.springframework.data.repository.core.support.AbstractRepositoryMetadata;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.data.repository.core.support.RepositoryFactorySupport;
import org.springframework.data.util.ReflectionUtils;
import org.springframework.lang.Nullable;
Expand Down Expand Up @@ -339,22 +340,30 @@ private ResolvableType getRepositoryInterface(RepositoryConfiguration<?> configu
return null;
}

TypeVariable<?>[] variables = factoryBean.getTypeParameters();
int numberOfGenerics = variables.length;
RepositoryMetadata metadata = AbstractRepositoryMetadata.getMetadata(repositoryInterface);
List<Class<?>> types = List.of(repositoryInterface, metadata.getDomainType(), metadata.getIdType());

ResolvableType[] generics = new ResolvableType[numberOfGenerics];
generics[0] = ResolvableType.forClass(repositoryInterface);
generics[1] = ResolvableType.forClass(metadata.getDomainType());
generics[2] = ResolvableType.forClass(metadata.getIdType());
ResolvableType[] declaredGenerics = ResolvableType.forClass(factoryBean).getGenerics();
ResolvableType[] parentGenerics = ResolvableType.forClass(RepositoryFactoryBeanSupport.class, factoryBean)
.getGenerics();
List<ResolvableType> resolvedGenerics = new ArrayList<ResolvableType>(factoryBean.getTypeParameters().length);

if (numberOfGenerics > 3) {
for (int i = 3; i < numberOfGenerics; i++) {
generics[i] = ResolvableType.forType(variables[0]);
for (int i = 0; i < parentGenerics.length; i++) {

ResolvableType parameter = parentGenerics[i];

if (parameter.getType() instanceof TypeVariable<?>) {
resolvedGenerics.add(i < types.size() ? ResolvableType.forClass(types.get(i)) : parameter);
}
}

if (resolvedGenerics.size() < declaredGenerics.length) {
for (int j = parentGenerics.length; j < declaredGenerics.length; j++) {
resolvedGenerics.add(declaredGenerics[j]);
}
}

return ResolvableType.forClassWithGenerics(factoryBean, generics);
return ResolvableType.forClassWithGenerics(factoryBean, resolvedGenerics.toArray(ResolvableType[]::new));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

import static org.assertj.core.api.Assertions.*;

import java.lang.reflect.TypeVariable;
import java.util.List;
import java.util.Optional;
import java.util.UUID;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mockito;
Expand All @@ -26,24 +31,30 @@
import org.springframework.aop.framework.Advised;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.beans.factory.ListableBeanFactory;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.parsing.BeanComponentDefinition;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.context.annotation.AnnotationBeanNameGenerator;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.ComponentScan.Filter;
import org.springframework.context.annotation.FilterType;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.ResolvableType;
import org.springframework.core.env.StandardEnvironment;
import org.springframework.core.metrics.ApplicationStartup;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.core.type.StandardAnnotationMetadata;
import org.springframework.data.mapping.Person;
import org.springframework.data.repository.Repository;
import org.springframework.data.repository.config.RepositoryConfigurationDelegate.LazyRepositoryInjectionPointResolver;
import org.springframework.data.repository.config.annotated.MyAnnotatedRepository;
import org.springframework.data.repository.config.annotated.MyAnnotatedRepositoryImpl;
import org.springframework.data.repository.config.annotated.MyFragmentImpl;
import org.springframework.data.repository.config.excluded.MyOtherRepositoryImpl;
import org.springframework.data.repository.config.stereotype.MyStereotypeRepository;
import org.springframework.data.repository.core.support.DummyRepositoryFactoryBean;
import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport;
import org.springframework.data.repository.sample.AddressRepository;
import org.springframework.data.repository.sample.AddressRepositoryClient;
import org.springframework.data.repository.sample.ProductRepository;
Expand Down Expand Up @@ -223,6 +234,38 @@ void registersAotPostProcessorForDifferentConfigurations() {
assertThat(context.getBeanNamesForType(RepositoryRegistrationAotProcessor.class)).hasSize(2);
}

@Test // GH-3074
void registersGenericsForIdConstrainingRepositoryFactoryBean() {

ResolvableType it = registerBeanDefinition(IdConstrainingRepositoryFactoryBean.class);

assertThat(it.getGenerics()).hasSize(2);
assertThat(it.getGeneric(0).resolve()).isEqualTo(MyAnnotatedRepository.class);
assertThat(it.getGeneric(1).resolve()).isEqualTo(Person.class);
}

@Test // GH-3074
void registersGenericsForDomainTypeConstrainingRepositoryFactoryBean() {

ResolvableType it = registerBeanDefinition(DomainTypeConstrainingRepositoryFactoryBean.class);

assertThat(it.getGenerics()).hasSize(2);
assertThat(it.getGeneric(0).resolve()).isEqualTo(MyAnnotatedRepository.class);
assertThat(it.getGeneric(1).resolve()).isEqualTo(String.class);
}

@Test // GH-3074
void registersGenericsForAdditionalGenericsRepositoryFactoryBean() {

ResolvableType it = registerBeanDefinition(AdditionalGenericsRepositoryFactoryBean.class);

assertThat(it.getGenerics()).hasSize(4);
assertThat(it.getGeneric(0).resolve()).isEqualTo(MyAnnotatedRepository.class);
assertThat(it.getGeneric(1).resolve()).isEqualTo(Person.class);
assertThat(it.getGeneric(2).resolve()).isEqualTo(String.class);
assertThat(it.getGeneric(3).getType()).isInstanceOf(TypeVariable.class);
}

private static ListableBeanFactory assertLazyRepositoryBeanSetup(Class<?> configClass) {

var context = new AnnotationConfigApplicationContext(configClass);
Expand Down Expand Up @@ -279,4 +322,56 @@ protected String getModulePrefix() {
return "commons";
}
}

private ResolvableType registerBeanDefinition(Class<?> repositoryFactoryType) {

AnnotationMetadata metadata = AnnotationMetadata.introspect(AnnotatedBeanNamesConfig.class);
AnnotationConfigApplicationContext context = new AnnotationConfigApplicationContext();

RepositoryConfigurationSource source = new AnnotationRepositoryConfigurationSource(metadata,
EnableRepositories.class, context, context.getEnvironment(),
context.getDefaultListableBeanFactory(), new AnnotationBeanNameGenerator()) {

@Override
public Optional<String> getRepositoryFactoryBeanClassName() {
return Optional.of(repositoryFactoryType.getName());
}
};

RepositoryConfigurationDelegate delegate = new RepositoryConfigurationDelegate(source, context,
context.getEnvironment());

List<BeanComponentDefinition> repositories = delegate.registerRepositoriesIn(context, extension);

assertThat(repositories).hasSize(1).element(0)
.extracting(BeanComponentDefinition::getBeanDefinition)
.extracting(BeanDefinition::getResolvableType)
.isNotNull();

return repositories.get(0).getBeanDefinition().getResolvableType();
}

static abstract class IdConstrainingRepositoryFactoryBean<T extends Repository<S, UUID>, S>
extends RepositoryFactoryBeanSupport<T, S, UUID> {

protected IdConstrainingRepositoryFactoryBean(Class<? extends T> repositoryInterface) {
super(repositoryInterface);
}
}

static abstract class DomainTypeConstrainingRepositoryFactoryBean<T extends Repository<Person, ID>, ID>
extends RepositoryFactoryBeanSupport<T, Person, ID> {

protected DomainTypeConstrainingRepositoryFactoryBean(Class<? extends T> repositoryInterface) {
super(repositoryInterface);
}
}

static abstract class AdditionalGenericsRepositoryFactoryBean<T extends Repository<S, ID>, S, ID, R>
extends RepositoryFactoryBeanSupport<T, S, ID> {

protected AdditionalGenericsRepositoryFactoryBean(Class<? extends T> repositoryInterface) {
super(repositoryInterface);
}
}
}

0 comments on commit 88011e6

Please sign in to comment.