package org.infinispan.client.hotrod.impl.iteration;

import org.infinispan.client.hotrod.impl.consistenthash.SegmentConsistentHash;
import org.infinispan.client.hotrod.logging.Log;
import org.infinispan.client.hotrod.logging.LogFactory;
import org.infinispan.commons.equivalence.ByteArrayEquivalence;
import org.infinispan.commons.util.BitSetUtils;
import org.infinispan.commons.util.CollectionFactory;
import org.infinispan.commons.util.Util;

import java.util.BitSet;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReferenceArray;

/**
 * {@link KeyTracker} that remembers, per segment of a {@link SegmentConsistentHash}, which keys have already been
 * seen during an iteration.
 * <p>
 * One key set is allocated for every segment being iterated. When segments are reported as finished via
 * {@link #segmentsFinished(byte[])}, their key sets are discarded. {@link #missedSegments()} can then report the
 * segments that have not completed yet.
 *
 * @author gustavonalle
 * @since 8.0
 */
class SegmentKeyTracker implements KeyTracker {

   private static final Log log = LogFactory.getLog(SegmentKeyTracker.class);

   /**
    * Slot {@code i} holds the keys tracked so far for segment {@code i}. It is {@code null} when segment {@code i}
    * was not requested at construction time or has since been reported as finished.
    */
   private final AtomicReferenceArray<Set<byte[]>> keysPerSegment;
   private final SegmentConsistentHash segmentConsistentHash;

   /**
    * Creates a tracker sized to the number of segments of the given hash.
    *
    * @param segmentConsistentHash hash used to map keys to segment indices
    * @param segments              segment indices to track, or {@code null} to track every segment of the hash
    */
   public SegmentKeyTracker(SegmentConsistentHash segmentConsistentHash, Set<Integer> segments) {
      int numSegments = segmentConsistentHash.getNumSegments();
      keysPerSegment = new AtomicReferenceArray<Set<byte[]>>(numSegments);
      if (log.isDebugEnabled()) log.debugf("Created SegmentKeyTracker with %d segments", numSegments);
      this.segmentConsistentHash = segmentConsistentHash;
      if (segments == null) {
         for (int i = 0; i < numSegments; i++) {
            keysPerSegment.set(i, CollectionFactory.makeSet(ByteArrayEquivalence.INSTANCE));
         }
      } else {
         for (Integer segment : segments) {
            keysPerSegment.set(segment, CollectionFactory.makeSet(ByteArrayEquivalence.INSTANCE));
         }
      }
   }

   /**
    * Records {@code key} in the key set of the segment it hashes to.
    * <p>
    * Keys are stored in sets built with {@link ByteArrayEquivalence}. Presumably that compares arrays by content;
    * confirm against commons.
    *
    * @param key the key bytes to track
    * @return {@code true} if the key was not already tracked for its segment, {@code false} if it was
    * @throws IllegalStateException if the key's segment is not tracked, i.e. it was not requested at construction
    *                               time or has already been reported as finished
    */
   public boolean track(byte[] key) {
      int segment = segmentConsistentHash.getSegment(key);
      Set<byte[]> keys = keysPerSegment.get(segment);
      if (keys == null) {
         // Without this guard the dereference below fails with an uninformative NullPointerException.
         throw new IllegalStateException("Segment " + segment + " is not being tracked (not requested or already finished)");
      }
      boolean result = keys.add(key);
      if (log.isTraceEnabled())
         log.trackingSegmentKey(Util.printArray(key), segment, !result);
      return result;
   }

   /**
    * Returns the segments that are still being tracked: those that were requested and not yet reported as finished.
    *
    * @return indices of unfinished segments (possibly empty), or {@code null} if the hash has zero segments.
    *         NOTE(review): callers presumably treat {@code null} as "no segment filter"; confirm before changing.
    */
   public Set<Integer> missedSegments() {
      int length = keysPerSegment.length();
      if (length == 0) return null;
      Set<Integer> missed = new HashSet<Integer>(length);
      for (int i = 0; i < length; i++) {
         if (keysPerSegment.get(i) != null) {
            missed.add(i);
         }
      }
      return missed;
   }

   /**
    * Marks segments as finished and releases their key sets.
    *
    * @param finishedSegments bit set, encoded as bytes, in which every set bit index is a finished segment.
    *                         A {@code null} value is ignored.
    */
   public void segmentsFinished(byte[] finishedSegments) {
      if (finishedSegments != null) {
         BitSet bitSet = BitSetUtils.fromByteArray(finishedSegments);
         if (log.isDebugEnabled()) log.debugf("Removing completed segments %s", bitSet);
         // Standard BitSet idiom for visiting every set bit in ascending order.
         for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i + 1)) {
            keysPerSegment.set(i, null);
         }
      }
   }
}
