package org.infinispan.client.hotrod.impl.iteration;

import org.infinispan.client.hotrod.MetadataValue;
import org.infinispan.client.hotrod.RemoteCache;
import org.infinispan.client.hotrod.RemoteCacheManager;
import org.infinispan.client.hotrod.test.MultiHotRodServersTest;
import org.infinispan.commons.marshall.Marshaller;
import org.infinispan.commons.util.CloseableIterator;
import org.infinispan.distribution.ch.ConsistentHash;
import org.infinispan.filter.AbstractKeyValueFilterConverter;
import org.infinispan.filter.KeyValueFilterConverter;
import org.infinispan.filter.KeyValueFilterConverterFactory;
import org.infinispan.filter.ParamKeyValueFilterConverterFactory;
import org.infinispan.metadata.Metadata;
import org.infinispan.query.dsl.embedded.testdomain.hsearch.AccountHS;
import org.infinispan.server.hotrod.HotRodServer;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import java.io.Serializable;
import java.util.HashSet;
import java.util.Map.Entry;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

import static org.infinispan.client.hotrod.impl.iteration.BaseMultiServerRemoteIteratorTest.SubstringFilterFactory.DEFAULT_LENGTH;
import static org.infinispan.client.hotrod.impl.iteration.RemoteIteratorTestUtils.*;
import static org.testng.Assert.assertTrue;
import static org.testng.AssertJUnit.assertEquals;
import static org.testng.AssertJUnit.assertFalse;

/**
 * @author gustavonalle
 * @since 8.0
 */
@Test(groups = "functional")
public abstract class BaseMultiServerRemoteIteratorTest extends MultiHotRodServersTest {

   public static final int CACHE_SIZE = 20;

   /**
    * Clears the default cache obtained from each client before every test method.
    */
   @BeforeMethod
   public void clear() {
      for (int clientIndex = 0; clientIndex < clients.size(); clientIndex++) {
         clients.get(clientIndex).getCache().clear();
      }
   }

   /**
    * Populates the cache with {@link #CACHE_SIZE} accounts and iterates over it with batch sizes
    * 1, 11, 21, ... (below 120), checking for each batch size that {@link #CACHE_SIZE} distinct
    * entries are returned and that their keys match the expected key range.
    */
   @Test
   public void testBatchSizes() {
      int maximumBatchSize = 120;
      RemoteCache<Integer, AccountHS> cache = clients.get(0).getCache();

      populateCache(CACHE_SIZE, new Function<Integer, AccountHS>() {
         @Override
         public AccountHS apply(Integer integer) {
            return newAccount(integer);
         }
      }, cache);

      Set<Integer> expectedKeys = range(0, CACHE_SIZE);

      for (int batch = 1; batch < maximumBatchSize; batch += 10) {
         Set<Entry<Object, Object>> results = new HashSet<Entry<Object, Object>>(CACHE_SIZE);
         CloseableIterator<Entry<Object, Object>> iterator = null;
         try {
            iterator = cache.retrieveEntries(null, null, batch);
            while (iterator.hasNext()) {
               results.add(iterator.next());
            }
         } finally {
            // Always release the iterator, even when iteration throws, so a failure
            // with one batch size does not leave an open iteration behind.
            if (iterator != null) {
               iterator.close();
            }
         }
         assertEquals(CACHE_SIZE, results.size());
         assertEquals(expectedKeys, extractKeys(results));
      }
   }

   /**
    * Opens an iterator over the cache without populating it first and checks that
    * {@code hasNext()} reports no entries, also when it is called a second time.
    */
   @Test
   public void testEmptyCache() {
      CloseableIterator<Entry<Object, Object>> emptyIterator = null;
      try {
         emptyIterator = client(0).getCache().retrieveEntries(null, null, 100);
         // hasNext() is deliberately queried twice in a row.
         for (int attempt = 0; attempt < 2; attempt++) {
            assertFalse(emptyIterator.hasNext());
         }
      } finally {
         if (emptyIterator != null) {
            emptyIterator.close();
         }
      }
   }

   /**
    * Registers {@link ToHexConverterFactory} on every server, stores {@link #CACHE_SIZE} integers
    * (value equal to key) and iterates over segments 15, 20 and 25 with that converter. Every key
    * that hashes to one of those segments must show up in the results as its hexadecimal string.
    */
   @Test
   public void testFilterBySegmentAndCustomFilter() {
      final String converterName = "toHexConverter";
      for (HotRodServer server : servers) {
         server.addKeyValueFilterConverterFactory(converterName, new ToHexConverterFactory());
      }

      RemoteCache<Integer, Integer> numbersCache = clients.get(0).getCache();

      // Each key is stored with itself as the value.
      Function<Integer, Integer> identity = new Function<Integer, Integer>() {
         @Override
         public Integer apply(Integer number) {
            return number;
         }
      };
      populateCache(CACHE_SIZE, identity, numbersCache);

      Set<Integer> segments = setOf(15, 20, 25);
      // extractEntries drains the iterator and closes it in a finally block.
      Set<Entry<Object, Object>> entries = extractEntries(numbersCache.retrieveEntries(converterName, segments, 10));

      Set<String> hexValues = extractValues(entries);
      for (Integer expectedKey : getKeysFromSegments(segments)) {
         String expectedHex = Integer.toHexString(expectedKey);
         assertTrue(hexValues.contains(expectedHex));
      }
   }

   /**
    * Registers {@link SubstringFilterFactory} on every server and stores random UUID strings.
    * Iterating with an explicit parameter must yield values of exactly that length; iterating
    * without a parameter must yield values of {@code DEFAULT_LENGTH} characters.
    */
   @Test
   public void testFilterByCustomParamFilter() {
      final String factoryName = "substringConverter";
      for (HotRodServer server : servers) {
         server.<String, String, String>addKeyValueFilterConverterFactory(factoryName, new SubstringFilterFactory());
      }
      final int filterParam = 12;

      RemoteCache<String, String> stringCache = clients.get(0).getCache();
      for (int idx : range(0, CACHE_SIZE - 1)) {
         String randomValue = UUID.randomUUID().toString();
         stringCache.put(String.valueOf(idx), randomValue);
      }

      // Explicit parameter: every converted value must have the requested length.
      org.infinispan.client.hotrod.impl.iteration.Condition<String> hasRequestedLength =
            new org.infinispan.client.hotrod.impl.iteration.Condition<String>() {
               @Override
               public boolean test(String s) {
                  return s.length() == filterParam;
               }
            };
      CloseableIterator<Entry<Object, Object>> paramIterator =
            stringCache.retrieveEntries(factoryName, new Object[]{filterParam}, null, 10);
      Set<Entry<Object, Object>> entries = extractEntries(paramIterator);
      Set<String> values = extractValues(entries);
      assertForAll(values, hasRequestedLength);

      // Omitting param, filter should use default value
      org.infinispan.client.hotrod.impl.iteration.Condition<String> hasDefaultLength =
            new org.infinispan.client.hotrod.impl.iteration.Condition<String>() {
               @Override
               public boolean test(String s) {
                  return s.length() == DEFAULT_LENGTH;
               }
            };
      entries = extractEntries(stringCache.retrieveEntries(factoryName, 10));
      values = extractValues(entries);
      assertForAll(values, hasDefaultLength);
   }


   /**
    * Drains the given iterator into a set and closes it afterwards, whether or not
    * the iteration completed normally.
    *
    * @param iterator the iterator to consume and close
    * @return all entries produced by the iterator
    */
   private Set<Entry<Object, Object>> extractEntries(CloseableIterator<Entry<Object, Object>> iterator) {
      final Set<Entry<Object, Object>> collected = new HashSet<Entry<Object, Object>>();
      try {
         while (iterator.hasNext()) {
            Entry<Object, Object> current = iterator.next();
            collected.add(current);
         }
      } finally {
         if (iterator != null) {
            iterator.close();
         }
      }
      return collected;
   }


   /**
    * Stores {@link #CACHE_SIZE} accounts, iterates restricted to segments 30..39 and verifies,
    * using the consistent hash of the first cache, that the returned keys belong to those segments.
    */
   @Test
   public void testFilterBySegment() {
      RemoteCache<Integer, AccountHS> cache = clients.get(0).getCache();
      Function<Integer, AccountHS> accountFactory = new Function<Integer, AccountHS>() {
         @Override
         public AccountHS apply(Integer id) {
            return newAccount(id);
         }
      };
      populateCache(CACHE_SIZE, accountFactory, cache);

      Set<Integer> filterBySegments = range(30, 40);

      Set<Entry<Object, Object>> entries = new HashSet<Entry<Object, Object>>();
      CloseableIterator<Entry<Object, Object>> segmentIterator = null;
      try {
         segmentIterator = cache.retrieveEntries(null, filterBySegments, 10);
         while (segmentIterator.hasNext()) {
            Entry<Object, Object> current = segmentIterator.next();
            entries.add(current);
         }
      } finally {
         if (segmentIterator != null) {
            segmentIterator.close();
         }
      }

      Marshaller marshaller = clients.get(0).getMarshaller();
      final ConsistentHash consistentHash = advancedCache(0).getDistributionManager().getConsistentHash();

      // Maps a marshalled key to the segment that owns it.
      Function<byte[], Integer> segmentOfKey = new Function<byte[], Integer>() {
         @Override
         public Integer apply(byte[] keyBytes) {
            return consistentHash.getSegment(keyBytes);
         }
      };
      assertKeysInSegment(entries, filterBySegments, marshaller, segmentOfKey);
   }

   /**
    * Stores three entries with different expiration settings and checks the lifespan and max-idle
    * values reported by the metadata-aware iterator for <em>every</em> returned entry:
    * <ul>
    *    <li>key 1: lifespan of one day, no max idle</li>
    *    <li>key 2: lifespan of two minutes, max idle of thirty seconds</li>
    *    <li>key 3: neither lifespan nor max idle</li>
    * </ul>
    * Also verifies that three distinct keys were returned.
    */
   @Test
   public void testRetrieveMetadata() throws Exception {
      RemoteCache<Integer, AccountHS> cache = clients.get(0).getCache();
      cache.put(1, newAccount(1), 1, TimeUnit.DAYS);
      cache.put(2, newAccount(2), 2, TimeUnit.MINUTES, 30, TimeUnit.SECONDS);
      cache.put(3, newAccount(3));

      Set<Object> seenKeys = new HashSet<Object>();
      CloseableIterator<Entry<Object, MetadataValue<Object>>> iterator = null;
      try {
         iterator = cache.retrieveEntriesWithMetadata(null, 10);

         // Walk the whole iteration: inspecting only the first entry would leave the other
         // two unverified and make the checked entry depend on iteration order.
         while (iterator.hasNext()) {
            Entry<Object, MetadataValue<Object>> entry = iterator.next();
            seenKeys.add(entry.getKey());

            if ((Integer) entry.getKey() == 1) {
               assertEquals(24 * 3600, entry.getValue().getLifespan());
               assertEquals(-1, entry.getValue().getMaxIdle());
            }
            if ((Integer) entry.getKey() == 2) {
               assertEquals(2 * 60, entry.getValue().getLifespan());
               assertEquals(30, entry.getValue().getMaxIdle());
            }
            if ((Integer) entry.getKey() == 3) {
               assertEquals(-1, entry.getValue().getLifespan());
               assertEquals(-1, entry.getValue().getMaxIdle());
            }
         }
      } finally {
         if (iterator != null) {
            iterator.close();
         }
      }

      assertEquals(3, seenKeys.size());
   }

   /**
    * Factory producing a converter that replaces each integer value with its hexadecimal string.
    */
   static final class ToHexConverterFactory implements KeyValueFilterConverterFactory<Integer, Integer, String> {

      /**
       * Converter that ignores key and metadata and renders the value via
       * {@link Integer#toHexString(int)}.
       */
      static class HexFilterConverter extends AbstractKeyValueFilterConverter<Integer, Integer, String> implements Serializable {
         @Override
         public String filterAndConvert(Integer key, Integer value, Metadata metadata) {
            String hex = Integer.toHexString(value);
            return hex;
         }
      }

      @Override
      public KeyValueFilterConverter<Integer, Integer, String> getFilterConverter() {
         return new HexFilterConverter();
      }

   }

   /**
    * Parameterised factory producing a converter that truncates each string value to a prefix.
    * The prefix length is taken from the first parameter, or {@link #DEFAULT_LENGTH} when no
    * parameter is supplied.
    */
   static final class SubstringFilterFactory implements ParamKeyValueFilterConverterFactory<String, String, String> {

      /** Prefix length used when the converter is created without parameters. */
      public static final int DEFAULT_LENGTH = 20;

      /**
       * @param params optional parameters; when present, the first element is cast to
       *               {@link Integer} and used as the prefix length
       * @return a converter truncating values to the requested length
       */
      @Override
      public KeyValueFilterConverter<String, String, String> getFilterConverter(Object[] params) {
         return new SubstringFilterConverter(params);
      }

      /**
       * @return a converter truncating values to {@link #DEFAULT_LENGTH} characters
       */
      @Override
      public KeyValueFilterConverter<String, String, String> getFilterConverter() {
         return new SubstringFilterConverter(null);
      }


      /**
       * Converter returning the first {@code length} characters of each value; key and metadata
       * are ignored.
       */
      static class SubstringFilterConverter extends AbstractKeyValueFilterConverter<String, String, String> implements Serializable {
         private final int length;

         /**
          * @param params {@code null} or empty to use {@link #DEFAULT_LENGTH}; otherwise the first
          *               element must be an {@link Integer} giving the prefix length
          */
         public SubstringFilterConverter(Object[] params) {
            if (params == null || params.length == 0) {
               length = DEFAULT_LENGTH;
            } else {
               length = (Integer) params[0];
            }
         }

         @Override
         public String filterAndConvert(String key, String value, Metadata metadata) {
            // String.substring throws if the value is shorter than the configured length.
            return value.substring(0, length);
         }
      }
   }

   /**
    * Collects the keys currently in the default cache whose marshalled form hashes, according to
    * the consistent hash of the first cache, to one of the given segments.
    *
    * @param segments the segment ids of interest
    * @return the cache keys located in any of those segments
    */
   private Set<Integer> getKeysFromSegments(Set<Integer> segments) {
      RemoteCacheManager firstClient = clients.get(0);
      RemoteCache<Integer, ?> cache = firstClient.getCache();
      Marshaller marshaller = firstClient.getMarshaller();
      ConsistentHash hash = advancedCache(0).getDistributionManager().getConsistentHash();

      Set<Integer> matchingKeys = new HashSet<Integer>();
      for (Integer key : cache.keySet()) {
         int segment = hash.getSegment(toByteBuffer(key, marshaller));
         if (!segments.contains(segment)) {
            continue;
         }
         matchingKeys.add(key);
      }
      return matchingKeys;
   }

}
