From 70005097b9400df11340dfe2d90fb140720c4686 Mon Sep 17 00:00:00 2001 From: Olivier Tremblay Date: Fri, 12 Oct 2018 19:04:45 -0400 Subject: [PATCH] Make map access less racy --- main.go | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/main.go b/main.go index c75e889..5a02e7c 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "os/user" "path/filepath" "strings" + "sync" "syscall" "golang.org/x/crypto/ssh" @@ -153,6 +154,24 @@ func SSHAgent() ssh.AuthMethod { func runServer(host, port, keyfile string) { filemap := map[string]*ssh.Channel{} + syncy := &sync.RWMutex{} + + mapget := func(key string) (*ssh.Channel, bool) { + syncy.RLock() + defer syncy.RUnlock() + c, ok := filemap[key] + return c, ok + } + mapset := func(key string, channel *ssh.Channel) { + syncy.Lock() + filemap[key] = channel + syncy.Unlock() + } + mapdel := func(key string) { + syncy.Lock() + delete(filemap, key) + syncy.Unlock() + } cfg := buildCfg() @@ -210,8 +229,8 @@ func runServer(host, port, keyfile string) { if err != nil { continue } - filemap[filereq.Path] = &channel - go func() { serverConn.Wait(); delete(filemap, filereq.Path) }() + mapset(filereq.Path, &channel) + go func() { serverConn.Wait(); mapdel(filereq.Path) }() return } }() @@ -221,7 +240,7 @@ func runServer(host, port, keyfile string) { }() http.ListenAndServe(":8888", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { fp := strings.TrimPrefix(req.URL.Path, "/") - if channel, ok := filemap[fp]; ok { + if channel, ok := mapget(fp); ok { defer func() { (*channel).Close() }() enc := gob.NewEncoder(*channel) err := enc.Encode(&FileReq{fp}) @@ -230,7 +249,7 @@ func runServer(host, port, keyfile string) { return } io.Copy(rw, *channel) - delete(filemap, fp) + mapdel(fp) return } else {