From cf6bf03d33e9a36d138db0e4f7e165e057ffcaf8 Mon Sep 17 00:00:00 2001 From: xing Date: Fri, 20 Jan 2023 18:43:11 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BC=93=E5=AD=98=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cache/vars.go | 7 +++-- cache/vars_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 cache/vars_test.go diff --git a/cache/vars.go b/cache/vars.go index 9324d81..42a719b 100644 --- a/cache/vars.go +++ b/cache/vars.go @@ -42,10 +42,13 @@ func (c *VarCache[T]) IsExpired() bool { } func (c *VarCache[T]) Flush() { - mu := c.v.Load().mutex + v := c.v.Load() + mu := v.mutex mu.Lock() defer mu.Unlock() - c.v.Delete() + var vv T + v.data = vv + c.v.Store(v) } func (c *VarCache[T]) GetCache(ctx context.Context, timeout time.Duration, params ...any) (T, error) { diff --git a/cache/vars_test.go b/cache/vars_test.go new file mode 100644 index 0000000..846f7f6 --- /dev/null +++ b/cache/vars_test.go @@ -0,0 +1,66 @@ +package cache + +import ( + "context" + "fmt" + "testing" + "time" +) + +var cc = *NewVarCache(func(a ...any) (int, error) { + return 1, nil +}, time.Minute) + +func TestVarCache_Flush(t *testing.T) { + type testCase[T any] struct { + name string + c VarCache[T] + } + tests := []testCase[int]{ + { + name: "1", + c: cc, + }, + } + c := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fmt.Println(tt.c.GetCache(c, time.Second)) + tt.c.Flush() + fmt.Println(tt.c.GetCache(c, time.Second)) + }) + } +} + +func TestVarCache_IsExpired(t *testing.T) { + type testCase[T any] struct { + name string + c VarCache[T] + want bool + } + tests := []testCase[int]{ + { + name: "expired", + c: cc, + want: true, + }, + { + name: "not expired", + c: func() VarCache[int] { + v := *NewVarCache(func(a ...any) (int, error) { + return 1, nil + }, time.Minute) + _, _ = v.GetCache(context.Background(), time.Second) + return v + }(), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.c.IsExpired(); got != tt.want { + t.Errorf("IsExpired() = %v, want %v", got, tt.want) + } + }) + } +}