// Copyright (C) 2015, 2022  HashiCorp, Inc.
// Copyright (C) 2023  Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: MPL-2.0
//
// Based on https://github.com/hashicorp/golang-lru/blob/efb1d5b30f66db326f4d8e27b3a5ad04f5e02ca3/arc_test.go

package containers

import (
	"bytes"
	"context"
	"crypto/rand"
	"fmt"
	"math/big"
	"sort"
	"testing"

	"github.com/datawire/dlib/derror"
	"github.com/datawire/dlib/dlog"
	"github.com/stretchr/testify/require"
)

// Add runtime validity checks /////////////////////////////////////////////////

func (c *arc[K, V]) logf(format string, a ...any) {
	c.t.Helper()
	c.t.Logf("%[1]T(%[1]p): %s (b1:%v t1:%v p1:%v / p1:%v, t2:%v b2:%v)",
		c,
		fmt.Sprintf(format, a...),
		c.recentGhost.Len,
		c.recentLive.Len,
		c.recentPinned.Len,
		c.frequentPinned.Len,
		c.frequentLive.Len,
		c.frequentGhost.Len)
}

func (c *arc[K, V]) fatalf(format string, a ...any) {
	c.logf(format, a...)
	c.t.FailNow()
}

func (c *arc[K, V]) check() {
	if c.noCheck {
		return
	}
	c.t.Helper()

	// Do the slow parts for 1/32 of all calls.
	fullCheck := getRand(c.t, 32) == 0

	// Check that the lists are in-sync with the maps.
	if fullCheck {
		liveEntries := make(map[*LinkedListEntry[arcLiveEntry[K, V]]]int, len(c.liveByName))
		for _, list := range c.liveLists {
			for entry := list.Oldest; entry != nil; entry = entry.Newer {
				liveEntries[entry]++
			}
		}
		for _, entry := range c.liveByName {
			liveEntries[entry]--
			if liveEntries[entry] == 0 {
				delete(liveEntries, entry)
			}
		}
		require.Len(c.t, liveEntries, 0)

		ghostEntries := make(map[*LinkedListEntry[arcGhostEntry[K]]]int, len(c.ghostByName))
		for _, list := range c.ghostLists {
			for entry := list.Oldest; entry != nil; entry = entry.Newer {
				ghostEntries[entry]++
			}
		}
		for _, entry := range c.ghostByName {
			ghostEntries[entry]--
			if ghostEntries[entry] == 0 {
				delete(ghostEntries, entry)
			}
		}
		require.Len(c.t, ghostEntries, 0)
	}

	// Check the invariants.

	// from DBL(2c):
	//
	//    • 0 ≤ |L₁|+|L₂| ≤ 2c
	if fullLen := c.recentPinned.Len + c.recentLive.Len + c.recentGhost.Len + c.frequentPinned.Len + c.frequentLive.Len + c.frequentGhost.Len; fullLen < 0 || fullLen > 2*c.cap {
		c.fatalf("! ( 0 <= fullLen:%v <= 2*cap:%v )", fullLen, c.cap)
	}
	//    • 0 ≤ |L₁| ≤ c
	if recentLen := c.recentPinned.Len + c.recentLive.Len + c.recentGhost.Len; recentLen < 0 || recentLen > c.cap {
		c.fatalf("! ( 0 <= recentLen:%v <= cap:%v )", recentLen, c.cap)
	}
	//    • 0 ≤ |L₂| ≤ 2c
	if frequentLen := c.frequentPinned.Len + c.frequentLive.Len + c.frequentGhost.Len; frequentLen < 0 || frequentLen > 2*c.cap {
		c.fatalf("! ( 0 <= frequentLen:%v <= 2*cap:%v )", frequentLen, c.cap)
	}
	//
	// from Π(c):
	//
	//    • A.1: The lists T₁, B₁, T₂, and B₂ are all mutually
	//      disjoint.
	if fullCheck {
		keys := make(map[K]int, len(c.liveByName)+len(c.ghostByName))
		for _, list := range c.liveLists {
			for entry := list.Oldest; entry != nil; entry = entry.Newer {
				keys[entry.Value.key]++
			}
		}
		for _, list := range c.ghostLists {
			for entry := list.Oldest; entry != nil; entry = entry.Newer {
				keys[entry.Value.key]++
			}
		}
		for key, cnt := range keys {
			if cnt > 1 {
				listNames := make([]string, 0, cnt)
				for listName, list := range c.liveLists {
					for entry := list.Oldest; entry != nil; entry = entry.Newer {
						if entry.Value.key == key {
							listNames = append(listNames, listName)
						}
					}
				}
				for listName, list := range c.ghostLists {
					for entry := list.Oldest; entry != nil; entry = entry.Newer {
						if entry.Value.key == key {
							listNames = append(listNames, listName)
						}
					}
				}
				sort.Strings(listNames)
				c.fatalf("dup key: %v is in %v", key, listNames)
			}
		}
	}
	//    • (not true) A.2: If |L₁|+|L₂| < c, then both B₁ and B₂ are
	//      empty.  But supporting "delete" invalidates this!
	//    • (not true) A.3: If |L₁|+|L₂| ≥ c, then |T₁|+|T₂| = c.  But
	//      supporting "delete" invalidates this!
	//    • A.4(a): Either (T₁ or B₁ is empty), or (the LRU page in T₁
	//      is more recent than the MRU page in B₁).
	//    • A.4(b): Either (T₂ or B₂ is empty), or (the LRU page in T₂
	//      is more recent than the MRU page in B₂).
	//    • A.5: |T₁∪T₂| is the set of pages that would be maintained
	//      by the cache policy π(c).
	//
	// from FRC(p, c):
	//
	//    • 0 ≤ p ≤ c
	if c.recentLiveTarget < 0 || c.recentLiveTarget > c.cap {
		c.fatalf("! ( 0 <= p:%v <= cap:%v )", c.recentLiveTarget, c.cap)
	}
}

// Compatibility layer for hashicorp/golang-lru ////////////////////////////////

type lenFunc func() int

func (fn lenFunc) Len() int { return fn() }

type arc[K comparable, V any] struct {
	*arCache[K, V]
	ctx context.Context //nolint:containedctx // have no choice to keep the hashicorp-compatible API
	t   testing.TB

	t1, t2, b1, b2 lenFunc

	// For speeding up .check()
	noCheck    bool
	liveLists  map[string]*LinkedList[arcLiveEntry[K, V]]
	ghostLists map[string]*LinkedList[arcGhostEntry[K]]
}

func NewARC[K comparable, V any](t testing.TB, size int) (*arc[K, V], error) {
	src := SourceFunc[K, V](func(context.Context, K, *V) {})
	_, isBench := t.(*testing.B)
	ret := &arc[K, V]{
		ctx:     dlog.NewTestContext(t, true),
		t:       t,
		noCheck: isBench,
	}
	ret.init(size, src)
	return ret, nil
}

func (c *arc[K, V]) init(size int, src Source[K, V]) {
	c.arCache = NewARCache[K, V](size, src).(*arCache[K, V])
	c.t1 = lenFunc(func() int { return c.arCache.recentLive.Len })
	c.t2 = lenFunc(func() int { return c.arCache.frequentLive.Len })
	c.b1 = lenFunc(func() int { return c.arCache.recentGhost.Len })
	c.b2 = lenFunc(func() int { return c.arCache.frequentGhost.Len })

	c.liveLists = map[string]*LinkedList[arcLiveEntry[K, V]]{
		"p1": &c.recentPinned,
		"t1": &c.recentLive,
		"p2": &c.frequentPinned,
		"t2": &c.frequentLive,
	}
	c.ghostLists = map[string]*LinkedList[arcGhostEntry[K]]{
		"b1": &c.recentGhost,
		"b2": &c.frequentGhost,
	}
}

// non-mutators

func (c *arc[K, V]) p() int            { return c.recentLiveTarget }
func (c *arc[K, V]) Len() int          { return len(c.liveByName) }
func (c *arc[K, V]) Contains(k K) bool { return c.liveByName[k] != nil }

func (c *arc[K, V]) Peek(k K) (V, bool) {
	entry := c.liveByName[k]
	if entry == nil {
		var zero V
		return zero, false
	}
	return entry.Value.val, true
}

func (c *arc[K, V]) Keys() []K {
	ret := make([]K, 0, len(c.liveByName))
	for entry := c.recentLive.Oldest; entry != nil; entry = entry.Newer {
		ret = append(ret, entry.Value.key)
	}
	for entry := c.frequentLive.Oldest; entry != nil; entry = entry.Newer {
		ret = append(ret, entry.Value.key)
	}
	return ret
}

// mutators

func (c *arc[K, V]) Remove(k K) {
	defer c.check()
	c.Delete(k)
}

func (c *arc[K, V]) Purge() {
	defer c.check()
	c.init(c.cap, c.src)
}

func (c *arc[K, V]) Get(k K) (V, bool) {
	defer c.check()
	if !c.Contains(k) {
		var zero V
		return zero, false
	}
	val := *c.Acquire(c.ctx, k)
	c.Release(k)
	return val, true
}

func (c *arc[K, V]) Add(k K, v V) {
	defer c.check()
	ptr := c.Acquire(c.ctx, k)
	*ptr = v
	c.Release(k)
}

// Tests from hashicorp/golang-lru /////////////////////////////////////////////

func getRand(tb testing.TB, limit int64) int64 {
	out, err := rand.Int(rand.Reader, big.NewInt(limit))
	if err != nil {
		tb.Fatal(err)
	}
	return out.Int64()
}

func BenchmarkARC_Rand(b *testing.B) {
	l, err := NewARC[int64, int64](b, 8192)
	if err != nil {
		b.Fatalf("err: %v", err)
	}

	trace := make([]int64, b.N*2)
	for i := 0; i < b.N*2; i++ {
		trace[i] = getRand(b, 32768)
	}

	b.ResetTimer()

	var hit, miss int
	for i := 0; i < 2*b.N; i++ {
		if i%2 == 0 {
			l.Add(trace[i], trace[i])
		} else {
			if _, ok := l.Get(trace[i]); ok {
				hit++
			} else {
				miss++
			}
		}
	}
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(miss))
}

func BenchmarkARC_Freq(b *testing.B) {
	l, err := NewARC[int64, int64](b, 8192)
	if err != nil {
		b.Fatalf("err: %v", err)
	}

	trace := make([]int64, b.N*2)
	for i := 0; i < b.N*2; i++ {
		if i%2 == 0 {
			trace[i] = getRand(b, 16384)
		} else {
			trace[i] = getRand(b, 32768)
		}
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		l.Add(trace[i], trace[i])
	}
	var hit, miss int
	for i := 0; i < b.N; i++ {
		if _, ok := l.Get(trace[i]); ok {
			hit++
		} else {
			miss++
		}
	}
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(miss))
}

type arcOp struct {
	Op  uint8  // [0,3)
	Key uint16 // [0,512)
}

func (op *arcOp) UnmarshalBinary(dat []byte) (int, error) {
	*op = arcOp{
		Op:  (dat[0] >> 6) % 3,
		Key: uint16(dat[0]&0b1)<<8 | uint16(dat[1]),
	}
	return 2, nil
}

func (op arcOp) MarshalBinary() ([]byte, error) {
	return []byte{
		(op.Op << 6) | byte(op.Key>>8),
		byte(op.Key),
	}, nil
}

type arcOps []arcOp

func (ops *arcOps) UnmarshalBinary(dat []byte) (int, error) {
	*ops = make(arcOps, len(dat)/2)
	for i := 0; i < len(dat)/2; i++ {
		_, _ = (*ops)[i].UnmarshalBinary(dat[i*2:])
	}
	return len(*ops) * 2, nil
}

func (ops arcOps) MarshalBinary() ([]byte, error) {
	dat := make([]byte, 0, len(ops)*2)
	for _, op := range ops {
		_dat, _ := op.MarshalBinary()
		dat = append(dat, _dat...)
	}
	return dat, nil
}

func FuzzARC(f *testing.F) {
	n := 200000
	seed := make([]byte, n*2)
	_, err := rand.Read(seed)
	require.NoError(f, err)
	f.Add(seed)

	f.Fuzz(func(t *testing.T, dat []byte) {
		var ops arcOps
		_, _ = ops.UnmarshalBinary(dat)
		defer func() {
			if err := derror.PanicToError(recover()); err != nil {
				t.Errorf("%+v", err)
			}
			if t.Failed() && bytes.Equal(dat, seed) {
				SaveFuzz(f, dat)
			}
		}()
		testARC_RandomOps(t, ops)
	})
}

func testARC_RandomOps(t *testing.T, ops []arcOp) {
	size := 128
	l, err := NewARC[int64, int64](t, 128)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	for _, op := range ops {
		key := int64(op.Key)
		r := op.Op
		switch r % 3 {
		case 0:
			l.Add(key, key)
		case 1:
			l.Get(key)
		case 2:
			l.Remove(key)
		}

		if l.t1.Len()+l.t2.Len() > size {
			t.Fatalf("bad: t1: %d t2: %d b1: %d b2: %d p: %d",
				l.t1.Len(), l.t2.Len(), l.b1.Len(), l.b2.Len(), l.p())
		}
		if l.b1.Len()+l.b2.Len() > size {
			t.Fatalf("bad: t1: %d t2: %d b1: %d b2: %d p: %d",
				l.t1.Len(), l.t2.Len(), l.b1.Len(), l.b2.Len(), l.p())
		}
	}
}

func TestARC_Get_RecentToFrequent(t *testing.T) {
	t.Parallel()
	l, err := NewARC[int, int](t, 128)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Touch all the entries, should be in t1
	for i := 0; i < 128; i++ {
		l.Add(i, i)
	}
	if n := l.t1.Len(); n != 128 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}

	// Get should upgrade to t2
	for i := 0; i < 128; i++ {
		if _, ok := l.Get(i); !ok {
			t.Fatalf("missing: %d", i)
		}
	}
	if n := l.t1.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 128 {
		t.Fatalf("bad: %d", n)
	}

	// Get be from t2
	for i := 0; i < 128; i++ {
		if _, ok := l.Get(i); !ok {
			t.Fatalf("missing: %d", i)
		}
	}
	if n := l.t1.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 128 {
		t.Fatalf("bad: %d", n)
	}
}

func TestARC_Add_RecentToFrequent(t *testing.T) {
	t.Parallel()
	l, err := NewARC[int, int](t, 128)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Add initially to t1
	l.Add(1, 1)
	if n := l.t1.Len(); n != 1 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}

	// Add should upgrade to t2
	l.Add(1, 1)
	if n := l.t1.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 1 {
		t.Fatalf("bad: %d", n)
	}

	// Add should remain in t2
	l.Add(1, 1)
	if n := l.t1.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 1 {
		t.Fatalf("bad: %d", n)
	}
}

func TestARC_Adaptive(t *testing.T) {
	t.Parallel()
	l, err := NewARC[int, int](t, 4)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	// Fill t1
	for i := 0; i < 4; i++ {
		l.Add(i, i)
	}
	require.Equal(t, `[   _0_    _1_    _2_    _3_ !^]___    ___    ___    ___    `, l.String())
	if n := l.t1.Len(); n != 4 {
		t.Fatalf("bad: %d", n)
	}

	// Move to t2
	l.Get(0)
	require.Equal(t, `    ___[   _1_    _2_    _3_ !^ _0_   ]___    ___    ___    `, l.String())
	l.Get(1)
	require.Equal(t, `    ___    ___[   _2_    _3_ !^ _1_    _0_   ]___    ___    `, l.String())
	if n := l.t2.Len(); n != 2 {
		t.Fatalf("bad: %d", n)
	}

	// Evict from t1
	l.Add(4, 4)
	if n := l.b1.Len(); n != 1 {
		t.Fatalf("bad: %d", n)
	}

	// Current state
	// t1 : (MRU) [4, 3] (LRU)
	// t2 : (MRU) [1, 0] (LRU)
	// b1 : (MRU) [2] (LRU)
	// b2 : (MRU) [] (LRU)
	require.Equal(t, `    ___    _2_[   _3_    _4_ !^ _1_    _0_   ]___    ___    `, l.String())

	// Add 2, should cause hit on b1
	l.Add(2, 2)
	require.Equal(t, `    ___    ___    _3_[ ^ _4_ !  _2_    _1_    _0_   ]___    `, l.String())
	if n := l.b1.Len(); n != 1 {
		t.Fatalf("bad: %d", n)
	}
	if l.p() != 1 {
		t.Fatalf("bad: %d", l.p())
	}
	if n := l.t2.Len(); n != 3 {
		t.Fatalf("bad: %d", n)
	}

	// Current state
	// t1 : (MRU) [4] (LRU)
	// t2 : (MRU) [2, 1, 0] (LRU)
	// b1 : (MRU) [3] (LRU)
	// b2 : (MRU) [] (LRU)
	require.Equal(t, `    ___    ___    _3_[ ^ _4_ !  _2_    _1_    _0_   ]___    `, l.String())

	// Add 4, should migrate to t2
	l.Add(4, 4)
	require.Equal(t, `    ___    ___    ___  ^ _3_[!  _4_    _2_    _1_    _0_   ]`, l.String())
	if n := l.t1.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 4 {
		t.Fatalf("bad: %d", n)
	}

	// Current state
	// t1 : (MRU) [] (LRU)
	// t2 : (MRU) [4, 2, 1, 0] (LRU)
	// b1 : (MRU) [3] (LRU)
	// b2 : (MRU) [] (LRU)
	require.Equal(t, `    ___    ___    ___  ^ _3_[!  _4_    _2_    _1_    _0_   ]`, l.String())

	// Add 5, should evict to b2
	l.Add(5, 5)
	require.Equal(t, `    ___    ___    _3_[ ^ _5_ !  _4_    _2_    _1_   ]_0_    `, l.String())
	if n := l.t1.Len(); n != 1 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 3 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.b2.Len(); n != 1 {
		t.Fatalf("bad: %d", n)
	}

	// Current state
	// t1 : (MRU) [5] (LRU)
	// t2 : (MRU) [4, 2, 1] (LRU)
	// b1 : (MRU) [3] (LRU)
	// b2 : (MRU) [0] (LRU)
	require.Equal(t, `    ___    ___    _3_[ ^ _5_ !  _4_    _2_    _1_   ]_0_    `, l.String())

	// Add 0, should decrease p
	l.Add(0, 0)
	require.Equal(t, `    ___    ___    _3_    _5_[!^ _0_    _4_    _2_    _1_   ]`, l.String())
	if n := l.t1.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.t2.Len(); n != 4 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.b1.Len(); n != 2 {
		t.Fatalf("bad: %d", n)
	}
	if n := l.b2.Len(); n != 0 {
		t.Fatalf("bad: %d", n)
	}
	if l.p() != 0 {
		t.Fatalf("bad: %d", l.p())
	}

	// Current state
	// t1 : (MRU) [] (LRU)
	// t2 : (MRU) [0, 4, 2, 1] (LRU)
	// b1 : (MRU) [5, 3] (LRU)
	// b2 : (MRU) [] (LRU)
	require.Equal(t, `    ___    ___    _3_    _5_[!^ _0_    _4_    _2_    _1_   ]`, l.String())
}

func TestARC(t *testing.T) {
	t.Parallel()
	l, err := NewARC[int, int](t, 128)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	for i := 0; i < 256; i++ {
		l.Add(i, i)
	}
	if l.Len() != 128 {
		t.Fatalf("bad len: %v", l.Len())
	}

	for i, k := range l.Keys() {
		if v, ok := l.Get(k); !ok || v != k || v != i+128 {
			t.Fatalf("bad key: %v", k)
		}
	}
	for i := 0; i < 128; i++ {
		if _, ok := l.Get(i); ok {
			t.Fatalf("should be evicted")
		}
	}
	for i := 128; i < 256; i++ {
		if _, ok := l.Get(i); !ok {
			t.Fatalf("should not be evicted")
		}
	}
	for i := 128; i < 192; i++ {
		l.Remove(i)
		if _, ok := l.Get(i); ok {
			t.Fatalf("should be deleted")
		}
	}

	l.Purge()
	if l.Len() != 0 {
		t.Fatalf("bad len: %v", l.Len())
	}
	if _, ok := l.Get(200); ok {
		t.Fatalf("should contain nothing")
	}
}

// Test that Contains doesn't update recent-ness
func TestARC_Contains(t *testing.T) {
	t.Parallel()
	l, err := NewARC[int, int](t, 2)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	l.Add(1, 1)
	l.Add(2, 2)
	if !l.Contains(1) {
		t.Errorf("1 should be contained")
	}

	l.Add(3, 3)
	if l.Contains(1) {
		t.Errorf("Contains should not have updated recent-ness of 1")
	}
}

// Test that Peek doesn't update recent-ness
func TestARC_Peek(t *testing.T) {
	t.Parallel()
	l, err := NewARC[int, int](t, 2)
	if err != nil {
		t.Fatalf("err: %v", err)
	}

	l.Add(1, 1)
	l.Add(2, 2)
	if v, ok := l.Peek(1); !ok || v != 1 {
		t.Errorf("1 should be set to 1: %v, %v", v, ok)
	}

	l.Add(3, 3)
	if l.Contains(1) {
		t.Errorf("should not have updated recent-ness of 1")
	}
}