Spring Cloud Gateway + Spring Cloud LoadBalancer 自定义负载均衡

1,配置类


package com.tycloud.gateway.config;

import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.loadbalancer.*;
import org.springframework.cloud.loadbalancer.annotation.LoadBalancerClient;
import org.springframework.cloud.loadbalancer.core.NoopServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.core.ReactorLoadBalancer;
import org.springframework.cloud.loadbalancer.core.ReactorServiceInstanceLoadBalancer;
import org.springframework.cloud.loadbalancer.core.ServiceInstanceListSupplier;
import org.springframework.cloud.loadbalancer.support.LoadBalancerClientFactory;
import org.springframework.context.annotation.Bean;
import org.springframework.core.env.Environment;
import reactor.core.publisher.Mono;

import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;


public class VersionLoadBalancerConfig {
    @Bean
    ReactorLoadBalancer versionGrayLoadBalancer(Environment environment,
                                                                 LoadBalancerClientFactory loadBalancerClientFactory) {
        String name = environment.getProperty(LoadBalancerClientFactory.PROPERTY_NAME);
        return new VersionGrayLoadBalancer(
                loadBalancerClientFactory.getLazyProvider(name, ServiceInstanceListSupplier.class), name);
    }

    public static class VersionGrayLoadBalancer implements ReactorServiceInstanceLoadBalancer {

        private static final Logger log = LoggerFactory.getLogger(VersionGrayLoadBalancer.class);

        private final ObjectProvider serviceInstanceListSupplierProvider;
        private final String serviceId;

        private final AtomicInteger position;

        public VersionGrayLoadBalancer(ObjectProvider serviceInstanceListSupplierProvider,
                                       String serviceId) {
            this(serviceInstanceListSupplierProvider, serviceId, new Random().nextInt(1000));
        }

        public VersionGrayLoadBalancer(ObjectProvider serviceInstanceListSupplierProvider,
                                       String serviceId, int seedPosition) {
            this.serviceId = serviceId;
            this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;
            this.position = new AtomicInteger(seedPosition);
        }

        @Override
        public Mono> choose(Request request) {
            ServiceInstanceListSupplier supplier =
                    this.serviceInstanceListSupplierProvider.getIfAvailable(NoopServiceInstanceListSupplier::new);
            return supplier.get(request).next()
                    .map(serviceInstances -> processInstanceResponse(serviceInstances, request));
        }


        private Response processInstanceResponse(List instances, Request request) {
            if (instances.isEmpty()) {
                if (log.isWarnEnabled()) {
                    log.warn("No servers available for service: " + this.serviceId);
                }
                return new EmptyResponse();
            }
            DefaultRequestContext requestContext = (DefaultRequestContext) request.getContext();
            RequestData clientRequest = (RequestData) requestContext.getClientRequest();
            List ips = clientRequest.getHeaders().get("X-Real-IP");
            String host = ips != null ? ips.get(0) : null;
            log.info("request header host:{}", host);
            if (StringUtils.isEmpty(host)){
                return processRibbonInstanceResponse(instances);
            }
            // filter service instances
            List serviceInstances = instances.stream()
                    .filter(instance -> instance.getUri().getHost().equals(host))
                    .collect(Collectors.toList());
            log.info("get ServiceInstance size:{}", serviceInstances.size());
            List availableInstances = serviceInstances.size() > 0 ? serviceInstances : instances;
            return processRibbonInstanceResponse(availableInstances);
        }


        private Response processRibbonInstanceResponse(List instances) {
            int pos = Math.abs(this.position.incrementAndGet());
            ServiceInstance instance = instances.get(pos % instances.size());
            return new DefaultResponse(instance);
        }
    }
}

2,启动类

只对单个服务生效:在启动类上添加 @LoadBalancerClient(name = "PERMISSION", configuration = VersionLoadBalancerConfig.class);如果所有服务都需要使用该负载均衡策略,则改用 @LoadBalancerClients(defaultConfiguration = VersionLoadBalancerConfig.class)。


/**
 * Gateway entry point.
 *
 * <p>Binds the "PERMISSION" load-balancer client to {@code VersionLoadBalancerConfig} and
 * registers the rate-limiting beans used by the gateway.
 */
@SpringBootApplication
@EnableDiscoveryClient
@LoadBalancerClient(name = "PERMISSION", configuration = VersionLoadBalancerConfig.class)
public class GatewayApplication {

    /**
     * Boots the Spring application.
     *
     * @param args command-line arguments forwarded to Spring Boot
     */
    public static void main(String[] args) {
        SpringApplication.run(GatewayApplication.class, args);
    }

    /**
     * Rate-limiter filter factory wired with the Redis rate limiter and the host-address key
     * resolver. Marked {@code @Primary}, presumably so it takes precedence over any other bean
     * of the same type (e.g. an auto-configured one). TODO confirm.
     */
    @Primary
    @Bean
    public RateLimiterGatewayFilterFactory rateLimiterGatewayFilterFactory(
            RedisRateLimiter redisRateLimiter,
            HostAddrKeyResolver resolver,
            ObjectMapper objectMapper) {
        RateLimiterGatewayFilterFactory filterFactory =
                new RateLimiterGatewayFilterFactory(redisRateLimiter, resolver, objectMapper);
        return filterFactory;
    }

    /**
     * Key resolver injected into {@link #rateLimiterGatewayFilterFactory}.
     */
    @Bean
    public HostAddrKeyResolver hostAddrKeyResolver() {
        HostAddrKeyResolver keyResolver = new HostAddrKeyResolver();
        return keyResolver;
    }

}

3,实现原理

在spring-cloud-commons模块里的org.springframework.cloud.client.loadbalancer.reactive包下有一个实现了BeanPostProcessor的LoadBalancerWebClientBuilderBeanPostProcessor类,源码如下所示。


/**
 * Excerpt from spring-cloud-commons: a BeanPostProcessor that adds the load-balancing exchange
 * filter to {@code WebClient.Builder} beans annotated with {@code @LoadBalanced}.
 */
public class LoadBalancerWebClientBuilderBeanPostProcessor implements BeanPostProcessor {
// Other members (including the `context` and `exchangeFilterFunction` fields used below) are elided in this excerpt.
...

    /**
     * Adds {@code exchangeFilterFunction} to a {@code WebClient.Builder} bean when that bean
     * carries {@code @LoadBalanced}. Every bean, decorated or not, is returned as-is.
     */
    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        if (bean instanceof WebClient.Builder) {
            // Builders without @LoadBalanced are left untouched.
            if (context.findAnnotationOnBean(beanName, LoadBalanced.class) == null) {
                return bean;
            }
            // Mutates the builder in place; the same instance is returned below.
            ((WebClient.Builder) bean).filter(exchangeFilterFunction);
        }
        return bean;
    }
}

实现了 BeanPostProcessor 接口的 bean 会参与 Spring 的 bean 生命周期,其 postProcessBeforeInitialization() 方法会在每个 bean 执行初始化回调之前被调用。可以看到,源码先判断该 bean 是否为 WebClient.Builder 且被 @LoadBalanced 注解标记。如果是,就为它添加一个 filter:DeferringLoadBalancerExchangeFilterFunction。这个 filter 内部持有一个 ExchangeFilterFunction 类型的代理(delegate),我们的负载均衡逻辑正是从外部设置到这个代理中的。
具体来说是org.springframework.cloud.client.loadbalancer.reactive.ReactorLoadBalancerExchangeFilterFunction这个类被设置给了WebClient的filter方法。而ReactorLoadBalancerExchangeFilterFunction又是通过其字段ReactiveLoadBalancer.Factory loadBalancerFactory将外面的配置传递进来的。
至此,问题转换为如何构建一个 ReactiveLoadBalancer.Factory 实例。

上面的代码都位于 spring-cloud-commons 抽象层中,接下来的代码则位于我们引入的具体实现 spring-cloud-loadbalancer 中。因为引入的是 starter,我们首先看 LoadBalancerAutoConfiguration 自动配置类,其中就有一个提供 ReactiveLoadBalancer.Factory 类型 Bean 的方法。


/**
 * Excerpt from LoadBalancerAutoConfiguration: registers a {@link LoadBalancerClientFactory}
 * (an implementation of ReactiveLoadBalancer.Factory) unless such a bean already exists.
 *
 * @return the factory, populated with the available client configurations
 */
@ConditionalOnMissingBean
@Bean
public LoadBalancerClientFactory loadBalancerClientFactory() {
    LoadBalancerClientFactory clientFactory = new LoadBalancerClientFactory();
    // Uses the injected configurations if present, otherwise an empty list. Presumably these are
    // the per-client specs registered via @LoadBalancerClient(s). TODO confirm.
    clientFactory.setConfigurations(this.configurations.getIfAvailable(Collections::emptyList));
    return clientFactory;
}

其中LoadBalancerClientFactory实现了ReactiveLoadBalancer.Factory接口。其构造函数如下


/**
 * Excerpt: LoadBalancerClientFactory constructor. Passes LoadBalancerClientConfiguration to the
 * superclass as the default configuration class, together with NAMESPACE and PROPERTY_NAME.
 */
public LoadBalancerClientFactory() {
    super(LoadBalancerClientConfiguration.class, NAMESPACE, PROPERTY_NAME);
}

它将默认配置类设置为 LoadBalancerClientConfiguration。
至此,默认负载均衡算法的配置位置已经定位到了 LoadBalancerClientConfiguration 类。
默认算法
查看LoadBalancerClientConfiguration类,发现了设置默认负载均衡算法的代码。


/**
 * Excerpt from LoadBalancerClientConfiguration: the default load balancer bean, registered only
 * when no other {@code ReactorLoadBalancer<ServiceInstance>} bean is present.
 *
 * @param environment               environment from which the client (service) name is read
 * @param loadBalancerClientFactory factory used to obtain a lazy instance-list supplier
 * @return a round-robin load balancer for that client
 */
@Bean
@ConditionalOnMissingBean
public ReactorLoadBalancer<ServiceInstance> reactorServiceInstanceLoadBalancer(Environment environment,
        LoadBalancerClientFactory loadBalancerClientFactory) {
    String name = environment.getProperty(LoadBalancerClientFactory.PROPERTY_NAME);
    return new RoundRobinLoadBalancer(
            loadBalancerClientFactory.getLazyProvider(name, ServiceInstanceListSupplier.class), name);
}

可以看到,默认使用RoundRobinLoadBalancer。
如何切换负载均衡算法
通过上面的原理分析,切换负载均衡算法就变得很简单了。
第一步:写一个自定义的负载均衡配置类
参考 LoadBalancerClientConfiguration,提供一个 ReactorLoadBalancer<ServiceInstance> 类型的 bean 即可。下面是我写的自定义配置类,使用 RandomLoadBalancer 实现随机负载均衡。RandomLoadBalancer 同样由 spring-cloud-loadbalancer 提供,如果要编写自己的 LoadBalancer,可以参考它的实现。


/**
 * Per-client load-balancer configuration that uses RandomLoadBalancer (random instance
 * selection).
 *
 * <p>Do not annotate with {@code @Configuration}, or else keep the class outside the
 * component-scan range. It should be loaded only via
 * {@code @LoadBalancerClient(configuration = ...)}.
 */
public class CustomLoadBalancerConfiguration {

    /**
     * Creates a random load balancer for the client whose name is exposed through
     * {@link LoadBalancerClientFactory#PROPERTY_NAME} in the environment.
     *
     * @param environment               environment from which the client (service) name is read
     * @param loadBalancerClientFactory factory used to obtain a lazy instance-list supplier
     * @return a RandomLoadBalancer for that client
     */
    @Bean
    public ReactorLoadBalancer<ServiceInstance> reactorServiceInstanceLoadBalancer(Environment environment,
            LoadBalancerClientFactory loadBalancerClientFactory) {
        String name = environment.getProperty(LoadBalancerClientFactory.PROPERTY_NAME);
        return new RandomLoadBalancer(
                loadBalancerClientFactory.getLazyProvider(name, ServiceInstanceListSupplier.class), name);
    }
}

特别注意:此配置类不要加 @Configuration 注解(或者确保它不在组件扫描范围内)。否则它会被组件扫描注册进主应用上下文,可能影响所有服务,而不只是 @LoadBalancerClient 指定的那个服务。
如果要提供自定义的ServiceInstanceListSupplier就在此类中加入自定义的bean即可,如下所示


/**
 * Supplies service instances obtained through the discovery client. The supplier is assembled
 * with the {@link ServiceInstanceListSupplier} builder against the given context.
 *
 * @param context application context passed to the builder
 * @return the discovery-client-backed instance-list supplier
 */
@Bean
public ServiceInstanceListSupplier myDiscoveryClientServiceInstanceListSupplier(
        ConfigurableApplicationContext context) {
    ServiceInstanceListSupplier discoverySupplier = ServiceInstanceListSupplier
            .builder()
            .withDiscoveryClient()
            .build(context);
    return discoverySupplier;
}

第二步:通过 @LoadBalancerClient 注解,把第一步的配置类指定给目标服务(此处为 order-service)


/**
 * Binds the "order-service" load-balancer client to {@code CustomLoadBalancerConfiguration} and
 * exposes a {@code @LoadBalanced} {@link WebClient.Builder}.
 */
@Configuration
@LoadBalancerClient(value = "order-service", configuration = CustomLoadBalancerConfiguration.class)
public class LoadBalanceConfiguration {

    /**
     * WebClient builder marked {@code @LoadBalanced}. This marker is what makes the load-balancing
     * exchange filter get attached to it.
     *
     * @return a fresh {@link WebClient.Builder}
     */
    @Bean
    @LoadBalanced
    public WebClient.Builder loadBalancedWebClientBuilder() {
        WebClient.Builder builder = WebClient.builder();
        return builder;
    }
}