Skip to content

backport: retry restarting HNS (#3529, #3540) #3563

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: release/v1.5
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 93 additions & 17 deletions platform/os_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/Azure/azure-container-networking/log"
"github.com/Azure/azure-container-networking/platform/windows/adapter"
"github.com/Azure/azure-container-networking/platform/windows/adapter/mellanox"
"github.com/avast/retry-go/v4"
"github.com/pkg/errors"
"go.uber.org/zap"
"golang.org/x/sys/windows"
Expand Down Expand Up @@ -232,32 +233,107 @@ func restartHNS(ctx context.Context) error {
}
defer service.Close()
// Stop the service
_, err = service.Control(svc.Stop)
if err != nil {
return errors.Wrap(err, "could not stop service")
log.Printf("Stopping HNS service")
_ = retry.Do(
tryStopServiceFn(ctx, service),
retry.UntilSucceeded(),
retry.Context(ctx),
retry.DelayType(retry.BackOffDelay),
)
// Start the service again
log.Printf("Starting HNS service")
_ = retry.Do(
tryStartServiceFn(ctx, service),
retry.UntilSucceeded(),
retry.Context(ctx),
retry.DelayType(retry.BackOffDelay),
)
log.Printf("HNS service started")
return nil
}

// managedService is the narrow set of service operations used by the
// start/stop helpers in this file. Depending on an interface rather than a
// concrete service handle lets callers pass in any implementation that
// provides these three methods.
type managedService interface {
	// Query returns the service's current status, or an error.
	Query() (svc.Status, error)
	// Start asks the service to start, forwarding any args.
	Start(args ...string) error
	// Control sends cmd (for example svc.Stop) to the service and returns
	// a status and an error.
	Control(cmd svc.Cmd) (svc.Status, error)
}

func tryStartServiceFn(ctx context.Context, service managedService) func() error {
shouldStart := func(state svc.State) bool {
return !(state == svc.Running || state == svc.StartPending)
}
// Wait for the service to stop
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
defer ticker.Stop()
for { // hacky cancellable do-while
return func() error {
status, err := service.Query()
if err != nil {
return errors.Wrap(err, "could not query service status")
}
if status.State == svc.Stopped {
break
if shouldStart(status.State) {
err = service.Start()
if err != nil {
return errors.Wrap(err, "could not start service")
}
}
select {
case <-ctx.Done():
return errors.New("context cancelled")
case <-ticker.C:
// Wait for the service to start
deadline, cancel := context.WithTimeout(ctx, 90*time.Second)
defer cancel()
ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
defer ticker.Stop()
for {
status, err := service.Query()
if err != nil {
return errors.Wrap(err, "could not query service status")
}
if status.State == svc.Running {
log.Printf("service started")
break
}
select {
case <-deadline.Done():
return deadline.Err() //nolint:wrapcheck // error has sufficient context
case <-ticker.C:
}
}
return nil
}
// Start the service again
if err := service.Start(); err != nil {
return errors.Wrap(err, "could not start service")
}

// tryStopServiceFn returns a function that performs one attempt at bringing
// service into the Stopped state. The returned function is shaped as a plain
// func() error so it can be handed to a retry helper.
//
// Each attempt:
//  1. queries the service and sends svc.Stop via Control unless the state is
//     already svc.Stopped or svc.StopPending;
//  2. polls Query every 500ms until svc.Stopped is observed.
//
// The polling phase is bounded by a 90-second timeout derived from ctx, so it
// also ends when ctx is done. The attempt returns nil once svc.Stopped is
// observed, a wrapped error if Query or Control fails, and the timeout
// context's error if svc.Stopped is not observed in time.
func tryStopServiceFn(ctx context.Context, service managedService) func() error {
	// Stop is only issued when the service is neither down nor already going down.
	shouldStop := func(state svc.State) bool {
		return !(state == svc.Stopped || state == svc.StopPending)
	}
	return func() error {
		status, err := service.Query()
		if err != nil {
			return errors.Wrap(err, "could not query service status")
		}
		if shouldStop(status.State) {
			if _, err = service.Control(svc.Stop); err != nil {
				return errors.Wrap(err, "could not stop service")
			}
		}
		// Wait for the service to stop
		deadline, cancel := context.WithTimeout(ctx, 90*time.Second)
		defer cancel()
		ticker := time.NewTicker(500 * time.Millisecond) //nolint:gomnd // 500ms
		defer ticker.Stop()
		for {
			// Re-query on every pass; the first pass runs immediately,
			// later passes run after a tick.
			status, err = service.Query()
			if err != nil {
				return errors.Wrap(err, "could not query service status")
			}
			if status.State == svc.Stopped {
				log.Printf("service stopped")
				return nil
			}
			select {
			case <-deadline.Done():
				return deadline.Err() //nolint:wrapcheck // error has sufficient context
			case <-ticker.C:
			}
		}
	}
}

func HasMellanoxAdapter() bool {
Expand Down
204 changes: 204 additions & 0 deletions platform/os_windows_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package platform

import (
"context"
"errors"
"os/exec"
"testing"
Expand All @@ -9,6 +10,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/windows/svc"
)

var errTestFailure = errors.New("test failure")
Expand Down Expand Up @@ -98,3 +100,205 @@ func TestExecuteCommandError(t *testing.T) {
assert.ErrorAs(t, err, &xErr)
assert.Equal(t, 1, xErr.ExitCode())
}

// mockManagedService is a scriptable test double whose Query, Control and
// Start methods delegate to the function values stored in its fields.
type mockManagedService struct {
	// queryFuncs is the scripted sequence of Query results; each Query call
	// consumes the first remaining entry.
	queryFuncs []func() (svc.Status, error)
	// controlFunc is invoked by Control. It may be left nil in cases where
	// Control is not expected to be called.
	controlFunc func(svc.Cmd) (svc.Status, error)
	// startFunc is invoked by Start. It may be left nil in cases where
	// Start is not expected to be called.
	startFunc func(args ...string) error
}

func (m *mockManagedService) Query() (svc.Status, error) {
Copy link
Preview

Copilot AI Apr 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a check to verify that 'queryFuncs' is not empty before accessing its first element to avoid potential panics in production code.

Suggested change
func (m *mockManagedService) Query() (svc.Status, error) {
func (m *mockManagedService) Query() (svc.Status, error) {
if len(m.queryFuncs) == 0 {
return svc.Status{}, errors.New("no query functions available")
}

Copilot is powered by AI, so mistakes are possible. Review output carefully before use.

queryFunc := m.queryFuncs[0]
m.queryFuncs = m.queryFuncs[1:]
return queryFunc()
}

// Control hands cmd to the configured controlFunc and returns its result
// unchanged.
func (m *mockManagedService) Control(cmd svc.Cmd) (svc.Status, error) {
	status, err := m.controlFunc(cmd)
	return status, err
}

// Start hands args to the configured startFunc and returns its error
// unchanged.
func (m *mockManagedService) Start(args ...string) error {
	err := m.startFunc(args...)
	return err
}

// TestTryStopServiceFn runs a single stop attempt against a scripted
// mockManagedService for each scenario and checks only whether the attempt
// reports an error.
func TestTryStopServiceFn(t *testing.T) {
	// reports builds a Query stub that succeeds with the given state.
	reports := func(state svc.State) func() (svc.Status, error) {
		return func() (svc.Status, error) {
			return svc.Status{State: state}, nil
		}
	}
	// stopAccepted is a Control stub that succeeds and reports Stopped.
	stopAccepted := func(svc.Cmd) (svc.Status, error) {
		return svc.Status{State: svc.Stopped}, nil
	}

	type scenario struct {
		name    string
		queries []func() (svc.Status, error)
		control func(svc.Cmd) (svc.Status, error)
		wantErr bool
	}
	cases := []scenario{
		{
			// Control stays nil: it must not be called when already stopped.
			name: "Service already stopped",
			queries: []func() (svc.Status, error){
				reports(svc.Stopped),
				reports(svc.Stopped),
			},
		},
		{
			name: "Service running and stops successfully",
			queries: []func() (svc.Status, error){
				reports(svc.Running),
				reports(svc.Stopped),
			},
			control: stopAccepted,
		},
		{
			name: "Service running and stops after multiple attempts",
			queries: []func() (svc.Status, error){
				reports(svc.Running),
				reports(svc.Running),
				reports(svc.Running),
				reports(svc.Stopped),
			},
			control: stopAccepted,
		},
		{
			name: "Service running and fails to stop",
			queries: []func() (svc.Status, error){
				reports(svc.Running),
			},
			control: func(svc.Cmd) (svc.Status, error) {
				return svc.Status{State: svc.Running}, errors.New("failed to stop service") //nolint:err113 // test error
			},
			wantErr: true,
		},
		{
			name: "Service query fails",
			queries: []func() (svc.Status, error){
				func() (svc.Status, error) {
					return svc.Status{}, errors.New("failed to query service status") //nolint:err113 // test error
				},
			},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			stub := &mockManagedService{
				queryFuncs:  tc.queries,
				controlFunc: tc.control,
			}
			attempt := tryStopServiceFn(context.Background(), stub)
			if err := attempt(); tc.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}

// TestTryStartServiceFn runs a single start attempt against a scripted
// mockManagedService for each scenario and checks only whether the attempt
// reports an error.
func TestTryStartServiceFn(t *testing.T) {
	// reports builds a Query stub that succeeds with the given state.
	reports := func(state svc.State) func() (svc.Status, error) {
		return func() (svc.Status, error) {
			return svc.Status{State: state}, nil
		}
	}

	type scenario struct {
		name    string
		queries []func() (svc.Status, error)
		start   func(...string) error
		wantErr bool
	}
	cases := []scenario{
		{
			// Start stays nil: it must not be called when already running.
			name: "Service already running",
			queries: []func() (svc.Status, error){
				reports(svc.Running),
				reports(svc.Running),
			},
		},
		{
			// Start stays nil: a pending start must not be re-issued.
			name: "Service already starting",
			queries: []func() (svc.Status, error){
				reports(svc.StartPending),
				reports(svc.Running),
			},
		},
		{
			name: "Service starts successfully",
			queries: []func() (svc.Status, error){
				reports(svc.Stopped),
				reports(svc.Running),
			},
			start: func(...string) error {
				return nil
			},
		},
		{
			name: "Service fails to start",
			queries: []func() (svc.Status, error){
				reports(svc.Stopped),
			},
			start: func(...string) error {
				return errors.New("failed to start service") //nolint:err113 // test error
			},
			wantErr: true,
		},
		{
			name: "Service query fails",
			queries: []func() (svc.Status, error){
				func() (svc.Status, error) {
					return svc.Status{}, errors.New("failed to query service status") //nolint:err113 // test error
				},
			},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			stub := &mockManagedService{
				queryFuncs: tc.queries,
				startFunc:  tc.start,
			}
			attempt := tryStartServiceFn(context.Background(), stub)
			if err := attempt(); tc.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
Loading