package org.infinispan.client.hotrod.impl;

import org.infinispan.client.hotrod.CacheTopologyInfo;
import org.infinispan.client.hotrod.configuration.Configuration;
import org.infinispan.client.hotrod.impl.consistenthash.ConsistentHash;
import org.infinispan.client.hotrod.impl.consistenthash.ConsistentHashFactory;
import org.infinispan.client.hotrod.impl.consistenthash.SegmentConsistentHash;
import org.infinispan.client.hotrod.impl.protocol.HotRodConstants;
import org.infinispan.client.hotrod.logging.Log;
import org.infinispan.client.hotrod.logging.LogFactory;
import org.infinispan.commons.equivalence.AnyEquivalence;
import org.infinispan.commons.equivalence.ByteArrayEquivalence;
import org.infinispan.commons.util.CollectionFactory;
import org.infinispan.commons.util.Immutables;

import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Maintains topology information about caches: the current list of servers and, per cache name,
 * the consistent hash, the number of segments and the topology id.
 * <p>
 * Cache names are {@code byte[]} keys. The per-cache maps are created with {@link ByteArrayEquivalence}
 * for their keys so that distinct array instances with the same contents resolve to the same entry.
 * <p>
 * NOTE(review): none of the state here is synchronized and {@code servers} is not volatile; confirm that
 * callers provide whatever synchronization is needed if an instance is shared between threads.
 *
 * @author gustavonalle
 */
public final class TopologyInfo {

   private static final Log log = LogFactory.getLog(TopologyInfo.class, Log.class);

   // This class replaces the reference (constructor / updateServers) rather than mutating the collection.
   private Collection<SocketAddress> servers = new ArrayList<SocketAddress>();
   private final Map<byte[], ConsistentHash> consistentHashes = CollectionFactory.makeMap(ByteArrayEquivalence.INSTANCE, AnyEquivalence.<ConsistentHash>getInstance());
   private final Map<byte[], Integer> segmentsByCache = CollectionFactory.makeMap(ByteArrayEquivalence.INSTANCE, AnyEquivalence.<Integer>getInstance());
   private final Map<byte[], AtomicInteger> topologyIds = CollectionFactory.makeMap(ByteArrayEquivalence.INSTANCE, AnyEquivalence.<AtomicInteger>getInstance());
   private final ConsistentHashFactory hashFactory = new ConsistentHashFactory();

   /**
    * Creates the topology information.
    *
    * @param topologyId    topology id registered under the empty cache name ({@code byte[0]}, presumably the
    *                      default cache); the instance itself is stored, not a copy of its value
    * @param servers       initial server list, stored by reference (not copied)
    * @param configuration configuration passed to {@link ConsistentHashFactory#init}
    */
   public TopologyInfo(AtomicInteger topologyId, Collection<SocketAddress> servers, Configuration configuration) {
      this.topologyIds.put(new byte[0], topologyId);
      this.servers = servers;
      hashFactory.init(configuration);
   }

   /**
    * Returns the segments owned by each server for the given cache.
    * <p>
    * When a consistent hash is stored for the cache, its own mapping is returned. Otherwise every current
    * server is reported as owning all segments {@code 0 .. numSegments-1} (no segments if the segment count
    * is unknown). In that fallback the returned map is wrapped as immutable, but all servers share the same
    * {@code Set} instance.
    */
   private Map<SocketAddress, Set<Integer>> getSegmentsByServer(byte[] cacheName) {
      ConsistentHash consistentHash = consistentHashes.get(cacheName);
      if (consistentHash != null) {
         return consistentHash.getSegmentsByServer();
      }
      Integer numSegments = segmentsByCache.get(cacheName);
      Set<Integer> segments = new HashSet<Integer>();
      if (numSegments != null) {
         for (int i = 0; i < numSegments; i++) {
            segments.add(i);
         }
      }
      Map<SocketAddress, Set<Integer>> segmentsPerServer = new HashMap<SocketAddress, Set<Integer>>();
      for (SocketAddress server : servers) {
         segmentsPerServer.put(server, segments);
      }
      return Immutables.immutableMapWrap(segmentsPerServer);
   }

   /**
    * Returns the current server list. This is the stored collection itself, not a copy.
    */
   public Collection<SocketAddress> getServers() {
      return servers;
   }

   /**
    * Replaces the consistent hash and topology id of a cache, using a server-to-hash-values mapping.
    * <p>
    * If the factory has no hash implementation for {@code hashFunctionVersion}, a warning is logged and
    * {@code null} is stored as the cache's hash, so {@link #getHashAwareServer} returns {@code null} for it.
    *
    * @param servers2Hash        hash values per server, passed to {@code ConsistentHash.init}
    * @param numKeyOwners        number of key owners, passed to {@code ConsistentHash.init}
    * @param hashFunctionVersion version used to select the hash implementation
    * @param hashSpace           hash space size, passed to {@code ConsistentHash.init}
    * @param cacheName           cache whose topology is updated
    * @param topologyId          topology id instance to store for the cache
    */
   public void updateTopology(Map<SocketAddress, Set<Integer>> servers2Hash, int numKeyOwners, short hashFunctionVersion, int hashSpace,
                              byte[] cacheName, AtomicInteger topologyId) {
      ConsistentHash hash = hashFactory.newConsistentHash(hashFunctionVersion);
      if (hash == null) {
         log.noHasHFunctionConfigured(hashFunctionVersion);
      } else {
         hash.init(servers2Hash, numKeyOwners, hashSpace);
      }
      consistentHashes.put(cacheName, hash);
      topologyIds.put(cacheName, topologyId);
   }

   /**
    * Replaces the segment-based topology of a cache.
    * <p>
    * The segment count and topology id are always stored. The consistent hash is only replaced when
    * {@code hashFunctionVersion > 0`; if the factory has no implementation for that version, a warning is
    * logged and {@code null} is stored.
    * NOTE(review): when {@code hashFunctionVersion <= 0}, a previously stored hash for this cache is kept
    * and is still used by {@link #getHashAwareServer} — confirm this is intended.
    *
    * @param segmentOwners       owners of each segment, passed to {@code SegmentConsistentHash.init}
    * @param numSegments         number of segments
    * @param hashFunctionVersion version used to select the hash implementation; {@code <= 0} skips hash creation
    * @param cacheName           cache whose topology is updated
    * @param topologyId          topology id instance to store for the cache
    */
   public void updateTopology(SocketAddress[][] segmentOwners, int numSegments, short hashFunctionVersion,
                              byte[] cacheName, AtomicInteger topologyId) {
      if (hashFunctionVersion > 0) {
         SegmentConsistentHash hash = hashFactory.newConsistentHash(hashFunctionVersion);
         if (hash == null) {
            log.noHasHFunctionConfigured(hashFunctionVersion);
         } else {
            hash.init(segmentOwners, numSegments);
         }
         consistentHashes.put(cacheName, hash);
      }
      segmentsByCache.put(cacheName, numSegments);
      topologyIds.put(cacheName, topologyId);
   }

   /**
    * Returns the server the consistent hash selects for {@code key}.
    *
    * @return the selected server, or {@code null} if the cache's topology is not valid
    *         (see {@link #isTopologyValid}) or no consistent hash is stored for the cache
    * @throws NullPointerException if no topology id is registered for {@code cacheName}
    */
   public SocketAddress getHashAwareServer(Object key, byte[] cacheName) {
      SocketAddress server = null;
      if (isTopologyValid(cacheName)) {
         ConsistentHash consistentHash = consistentHashes.get(cacheName);
         if (consistentHash != null) {
            server = consistentHash.getServer(key);
            if (log.isTraceEnabled()) {
               // Pass the address as an argument: its toString() may contain '%' (e.g. an IPv6 scope id
               // such as "fe80::1%eth0"), which must not be interpreted as a format specifier.
               log.tracef("Using consistent hash for determining the server: %s", server);
            }
         }
      }
      return server;
   }

   /**
    * Tells whether the cache's current topology id differs from
    * {@link HotRodConstants#SWITCH_CLUSTER_TOPOLOGY}, which marks the topology as not usable.
    *
    * @throws NullPointerException if no topology id is registered for {@code cacheName}
    */
   public boolean isTopologyValid(byte[] cacheName) {
      // Primitives: no boxing, and the comparison is always numeric.
      int id = topologyIds.get(cacheName).get();
      boolean valid = id != HotRodConstants.SWITCH_CLUSTER_TOPOLOGY;
      if (log.isTraceEnabled()) {
         log.tracef("Is topology id (%s) valid? %b", id, valid);
      }
      return valid;
   }

   /**
    * Replaces the server list with {@code updatedServers}, which is stored by reference (not copied).
    */
   public void updateServers(Collection<SocketAddress> updatedServers) {
      servers = updatedServers;
   }

   /**
    * Returns the consistent hash stored for the cache, or {@code null} if there is none.
    */
   public ConsistentHash getConsistentHash(byte[] cacheName) {
      return consistentHashes.get(cacheName);
   }

   /**
    * Returns the factory used to create consistent hashes.
    */
   public ConsistentHashFactory getConsistentHashFactory() {
      return hashFactory;
   }

   /**
    * Registers a new topology id holder for the cache, replacing any existing one.
    *
    * @return the newly created holder, initialised to {@code topologyId}
    */
   public AtomicInteger createTopologyId(byte[] cacheName, int topologyId) {
      AtomicInteger id = new AtomicInteger(topologyId);
      this.topologyIds.put(cacheName, id);
      return id;
   }

   /**
    * Sets the value of the cache's existing topology id holder.
    *
    * @throws NullPointerException if no topology id is registered for {@code cacheName}
    */
   public void setTopologyId(byte[] cacheName, int topologyId) {
      AtomicInteger id = this.topologyIds.get(cacheName);
      id.set(topologyId);
   }

   /**
    * Sets every registered topology id holder to {@code newTopologyId}.
    */
   public void setAllTopologyIds(int newTopologyId) {
      for (AtomicInteger topologyId : topologyIds.values()) {
         topologyId.set(newTopologyId);
      }
   }

   /**
    * Returns the cache's current topology id.
    *
    * @throws NullPointerException if no topology id is registered for {@code cacheName}
    */
   public int getTopologyId(byte[] cacheName) {
      return topologyIds.get(cacheName).get();
   }

   /**
    * Builds a {@link CacheTopologyInfo} snapshot from the segments per server, the segment count
    * (possibly {@code null} if unknown) and the current topology id of the cache.
    *
    * @throws NullPointerException if no topology id is registered for {@code cacheName}
    */
   public CacheTopologyInfo getCacheTopologyInfo(byte[] cacheName) {
      return new CacheTopologyInfoImpl(getSegmentsByServer(cacheName), segmentsByCache.get(cacheName),
              topologyIds.get(cacheName).get());
   }

}
