mirror of
				https://github.com/ehang-io/nps
				synced 2025-10-26 17:35:42 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			202 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			202 lines
		
	
	
		
			4.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package file
 | |
| 
 | |
| import (
 | |
| 	"encoding/json"
 | |
| 	"errors"
 | |
| 	"github.com/astaxie/beego/logs"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"strings"
 | |
| 	"sync"
 | |
| 	"sync/atomic"
 | |
| 
 | |
| 	"ehang.io/nps/lib/common"
 | |
| 	"ehang.io/nps/lib/rate"
 | |
| )
 | |
| 
 | |
| func NewJsonDb(runPath string) *JsonDb {
 | |
| 	return &JsonDb{
 | |
| 		RunPath:        runPath,
 | |
| 		TaskFilePath:   filepath.Join(runPath, "conf", "tasks.json"),
 | |
| 		HostFilePath:   filepath.Join(runPath, "conf", "hosts.json"),
 | |
| 		ClientFilePath: filepath.Join(runPath, "conf", "clients.json"),
 | |
| 	}
 | |
| }
 | |
| 
 | |
// JsonDb is the in-memory store for tasks, hosts and clients, together with
// the id counters and the paths of the JSON files backing each collection.
type JsonDb struct {
	// Tasks maps a tunnel Id to its *Tunnel.
	Tasks sync.Map
	// Hosts maps a host Id to its *Host.
	Hosts sync.Map
	// HostsTmp is not referenced in this file.
	// NOTE(review): purpose must be confirmed against other files in the package.
	HostsTmp sync.Map
	// Clients maps a client Id to its *Client.
	Clients sync.Map

	// RunPath is the base directory passed to NewJsonDb.
	RunPath string

	// ClientIncreaseId is the last client id handed out or loaded.
	ClientIncreaseId int32
	// TaskIncreaseId is the last task id handed out or loaded.
	TaskIncreaseId int32
	// HostIncreaseId is the last host id handed out or loaded.
	HostIncreaseId int32

	// TaskFilePath is the path of the task file.
	TaskFilePath string
	// HostFilePath is the path of the host file.
	HostFilePath string
	// ClientFilePath is the path of the client file.
	ClientFilePath string
}
 | |
| 
 | |
| func (s *JsonDb) LoadTaskFromJsonFile() {
 | |
| 	loadSyncMapFromFile(s.TaskFilePath, func(v string) {
 | |
| 		var err error
 | |
| 		post := new(Tunnel)
 | |
| 		if json.Unmarshal([]byte(v), &post) != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		if post.Client, err = s.GetClient(post.Client.Id); err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		s.Tasks.Store(post.Id, post)
 | |
| 		if post.Id > int(s.TaskIncreaseId) {
 | |
| 			s.TaskIncreaseId = int32(post.Id)
 | |
| 		}
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func (s *JsonDb) LoadClientFromJsonFile() {
 | |
| 	loadSyncMapFromFile(s.ClientFilePath, func(v string) {
 | |
| 		post := new(Client)
 | |
| 		if json.Unmarshal([]byte(v), &post) != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		if post.RateLimit > 0 {
 | |
| 			post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
 | |
| 		} else {
 | |
| 			post.Rate = rate.NewRate(int64(2 << 23))
 | |
| 		}
 | |
| 		post.Rate.Start()
 | |
| 		post.NowConn = 0
 | |
| 		s.Clients.Store(post.Id, post)
 | |
| 		if post.Id > int(s.ClientIncreaseId) {
 | |
| 			s.ClientIncreaseId = int32(post.Id)
 | |
| 		}
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func (s *JsonDb) LoadHostFromJsonFile() {
 | |
| 	loadSyncMapFromFile(s.HostFilePath, func(v string) {
 | |
| 		var err error
 | |
| 		post := new(Host)
 | |
| 		if json.Unmarshal([]byte(v), &post) != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		if post.Client, err = s.GetClient(post.Client.Id); err != nil {
 | |
| 			return
 | |
| 		}
 | |
| 		s.Hosts.Store(post.Id, post)
 | |
| 		if post.Id > int(s.HostIncreaseId) {
 | |
| 			s.HostIncreaseId = int32(post.Id)
 | |
| 		}
 | |
| 	})
 | |
| }
 | |
| 
 | |
| func (s *JsonDb) GetClient(id int) (c *Client, err error) {
 | |
| 	if v, ok := s.Clients.Load(id); ok {
 | |
| 		c = v.(*Client)
 | |
| 		return
 | |
| 	}
 | |
| 	err = errors.New("未找到客户端")
 | |
| 	return
 | |
| }
 | |
| 
 | |
| var hostLock sync.Mutex
 | |
| 
 | |
| func (s *JsonDb) StoreHostToJsonFile() {
 | |
| 	hostLock.Lock()
 | |
| 	storeSyncMapToFile(s.Hosts, s.HostFilePath)
 | |
| 	hostLock.Unlock()
 | |
| }
 | |
| 
 | |
| var taskLock sync.Mutex
 | |
| 
 | |
| func (s *JsonDb) StoreTasksToJsonFile() {
 | |
| 	taskLock.Lock()
 | |
| 	storeSyncMapToFile(s.Tasks, s.TaskFilePath)
 | |
| 	taskLock.Unlock()
 | |
| }
 | |
| 
 | |
| var clientLock sync.Mutex
 | |
| 
 | |
| func (s *JsonDb) StoreClientsToJsonFile() {
 | |
| 	clientLock.Lock()
 | |
| 	storeSyncMapToFile(s.Clients, s.ClientFilePath)
 | |
| 	clientLock.Unlock()
 | |
| }
 | |
| 
 | |
| func (s *JsonDb) GetClientId() int32 {
 | |
| 	return atomic.AddInt32(&s.ClientIncreaseId, 1)
 | |
| }
 | |
| 
 | |
| func (s *JsonDb) GetTaskId() int32 {
 | |
| 	return atomic.AddInt32(&s.TaskIncreaseId, 1)
 | |
| }
 | |
| 
 | |
| func (s *JsonDb) GetHostId() int32 {
 | |
| 	return atomic.AddInt32(&s.HostIncreaseId, 1)
 | |
| }
 | |
| 
 | |
| func loadSyncMapFromFile(filePath string, f func(value string)) {
 | |
| 	b, err := common.ReadAllFromFile(filePath)
 | |
| 	if err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| 	for _, v := range strings.Split(string(b), "\n"+common.CONN_DATA_SEQ) {
 | |
| 		f(v)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func storeSyncMapToFile(m sync.Map, filePath string) {
 | |
| 	file, err := os.Create(filePath + ".tmp")
 | |
| 	// first create a temporary file to store
 | |
| 	if err != nil {
 | |
| 		panic(err)
 | |
| 	}
 | |
| 	m.Range(func(key, value interface{}) bool {
 | |
| 		var b []byte
 | |
| 		var err error
 | |
| 		switch value.(type) {
 | |
| 		case *Tunnel:
 | |
| 			obj := value.(*Tunnel)
 | |
| 			if obj.NoStore {
 | |
| 				return true
 | |
| 			}
 | |
| 			b, err = json.Marshal(obj)
 | |
| 		case *Host:
 | |
| 			obj := value.(*Host)
 | |
| 			if obj.NoStore {
 | |
| 				return true
 | |
| 			}
 | |
| 			b, err = json.Marshal(obj)
 | |
| 		case *Client:
 | |
| 			obj := value.(*Client)
 | |
| 			if obj.NoStore {
 | |
| 				return true
 | |
| 			}
 | |
| 			b, err = json.Marshal(obj)
 | |
| 		default:
 | |
| 			return true
 | |
| 		}
 | |
| 		if err != nil {
 | |
| 			return true
 | |
| 		}
 | |
| 		_, err = file.Write(b)
 | |
| 		if err != nil {
 | |
| 			panic(err)
 | |
| 		}
 | |
| 		_, err = file.Write([]byte("\n" + common.CONN_DATA_SEQ))
 | |
| 		if err != nil {
 | |
| 			panic(err)
 | |
| 		}
 | |
| 		return true
 | |
| 	})
 | |
| 	_ = file.Sync()
 | |
| 	_ = file.Close()
 | |
| 	// must close file first, then rename it
 | |
| 	err = os.Rename(filePath+".tmp", filePath)
 | |
| 	if err != nil {
 | |
| 		logs.Error(err, "store to file err, data will lost")
 | |
| 	}
 | |
| 	// replace the file, maybe provides atomic operation
 | |
| }
 | 
