Make map access less racy

This commit is contained in:
Olivier Tremblay 2018-10-12 19:04:45 -04:00
parent 602d9e8809
commit 70005097b9
No known key found for this signature in database
GPG key ID: D1C73ACB855E3A6D

27
main.go
View file

@ -14,6 +14,7 @@ import (
"os/user" "os/user"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"syscall" "syscall"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
@ -153,6 +154,24 @@ func SSHAgent() ssh.AuthMethod {
func runServer(host, port, keyfile string) { func runServer(host, port, keyfile string) {
filemap := map[string]*ssh.Channel{} filemap := map[string]*ssh.Channel{}
syncy := &sync.RWMutex{}
mapget := func(key string) (*ssh.Channel, bool) {
syncy.RLock()
defer syncy.RUnlock()
c, ok := filemap[key]
return c, ok
}
mapset := func(key string, channel *ssh.Channel) {
syncy.Lock()
filemap[key] = channel
syncy.Unlock()
}
mapdel := func(key string) {
syncy.Lock()
delete(filemap, key)
syncy.Unlock()
}
cfg := buildCfg() cfg := buildCfg()
@ -210,8 +229,8 @@ func runServer(host, port, keyfile string) {
if err != nil { if err != nil {
continue continue
} }
filemap[filereq.Path] = &channel mapset(filereq.Path, &channel)
go func() { serverConn.Wait(); delete(filemap, filereq.Path) }() go func() { serverConn.Wait(); mapdel(filereq.Path) }()
return return
} }
}() }()
@ -221,7 +240,7 @@ func runServer(host, port, keyfile string) {
}() }()
http.ListenAndServe(":8888", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { http.ListenAndServe(":8888", http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
fp := strings.TrimPrefix(req.URL.Path, "/") fp := strings.TrimPrefix(req.URL.Path, "/")
if channel, ok := filemap[fp]; ok { if channel, ok := mapget(fp); ok {
defer func() { (*channel).Close() }() defer func() { (*channel).Close() }()
enc := gob.NewEncoder(*channel) enc := gob.NewEncoder(*channel)
err := enc.Encode(&FileReq{fp}) err := enc.Encode(&FileReq{fp})
@ -230,7 +249,7 @@ func runServer(host, port, keyfile string) {
return return
} }
io.Copy(rw, *channel) io.Copy(rw, *channel)
delete(filemap, fp) mapdel(fp)
return return
} else { } else {