Support multiple downloads

This commit is contained in:
Olivier Tremblay 2018-10-15 07:29:55 -04:00
parent 03442ed63f
commit 2e8083e42f
No known key found for this signature in database
GPG key ID: D1C73ACB855E3A6D
2 changed files with 81 additions and 35 deletions

View file

@ -1,2 +1,11 @@
# sharethis # sharethis
Easy file sharing for the mere mortal Easy file sharing for the mere mortal
# TODO
- Better documentation. :P
- Logging (probably with log levels too)
- Per-account bandwidth quotas
- Admin functionality (use the client to add or remove keys, adjust quotas, etc)
- Remove hardcoded idiotic values
- UPnP hole-punching (so that the server does not have to proxy files (as in direct peer-to-peer file sending))

107
main.go
View file

@ -22,12 +22,9 @@ import (
) )
type FileReq struct { type FileReq struct {
Path string Path string
} ShareCount uint
serverconn *ssh.ServerConn
type chanpair struct {
sshchan *ssh.Channel
sigchan chan struct{}
} }
func main() { func main() {
@ -36,6 +33,7 @@ func main() {
remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact") remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact")
sshport := flag.String("sshport", "2022", "the remote ssh port") sshport := flag.String("sshport", "2022", "the remote ssh port")
httpport := flag.String("httpport", "8888", "the remote server's http port") httpport := flag.String("httpport", "8888", "the remote server's http port")
sharecount := flag.Uint("count", 1, "Amount of times you want to share this file")
flag.Parse() flag.Parse()
if *server { if *server {
runServer("0.0.0.0", *sshport, *httpport, "id_rsa") runServer("0.0.0.0", *sshport, *httpport, "id_rsa")
@ -107,31 +105,55 @@ func main() {
// In the words of weezer, I've got my hashed path. // In the words of weezer, I've got my hashed path.
// TODO: Get the remote URL from the remote server instead of rebuilding it locally. // TODO: Get the remote URL from the remote server instead of rebuilding it locally.
// TODO: Clean up the port from the URL if it's 80 or 443 // TODO: Clean up the port from the URL if it's 80 or 443
fmt.Println(fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath)) var fileurl string
err = enc.Encode(&FileReq{hashedpath}) if *httpport == "443" {
fileurl = fmt.Sprintf("https://%s/%s", *remotehost, hashedpath)
} else if *httpport == "80" {
fileurl = fmt.Sprintf("http://%s/%s", *remotehost, hashedpath)
} else {
fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath)
}
fmt.Println(fileurl)
err = enc.Encode(&FileReq{Path: hashedpath, ShareCount: *sharecount})
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
ch.Close() ch.Close()
os.Exit(1) os.Exit(1)
} }
dec := gob.NewDecoder(ch)
var fr *FileReq mu := &sync.Mutex{}
defer ch.Close()
ncc := connection.HandleChannelOpen(hashedpath)
for { for {
err := dec.Decode(&fr) nc := <-ncc
ch, req, err := nc.Accept()
go ssh.DiscardRequests(req)
if err != nil { if err != nil {
fmt.Println(err) fmt.Fprintln(os.Stderr, err)
continue
} }
if fr.Path == hashedpath {
defer ch.Close() go func() {
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
fmt.Fprintln(ch, "Sharethis error") fmt.Fprintln(ch, "Sharethis error")
os.Exit(1) return
} }
io.Copy(ch, f) _, err = io.Copy(ch, f)
return if err != nil {
} fmt.Println(err)
return
}
mu.Lock()
if *sharecount == 0 {
mu.Unlock()
os.Exit(0)
}
*sharecount--
mu.Unlock()
ch.Close()
}()
} }
} }
@ -159,18 +181,18 @@ func SSHAgent() ssh.AuthMethod {
} }
func runServer(host, sshport, httpport, keyfile string) { func runServer(host, sshport, httpport, keyfile string) {
filemap := map[string]*ssh.Channel{} filemap := map[string]*FileReq{}
syncy := &sync.RWMutex{} syncy := &sync.RWMutex{}
mapget := func(key string) (*ssh.Channel, bool) { mapget := func(key string) (*FileReq, bool) {
syncy.RLock() syncy.RLock()
defer syncy.RUnlock() defer syncy.RUnlock()
c, ok := filemap[key] c, ok := filemap[key]
return c, ok return c, ok
} }
mapset := func(key string, channel *ssh.Channel) { mapset := func(key string, fr *FileReq) {
syncy.Lock() syncy.Lock()
filemap[key] = channel filemap[key] = fr
syncy.Unlock() syncy.Unlock()
} }
mapdel := func(key string) { mapdel := func(key string) {
@ -235,7 +257,10 @@ func runServer(host, sshport, httpport, keyfile string) {
if err != nil { if err != nil {
continue continue
} }
mapset(filereq.Path, &channel)
// TODO: also take the sharecount in the map.
filereq.serverconn = serverConn
mapset(filereq.Path, &filereq)
go func() { serverConn.Wait(); mapdel(filereq.Path) }() go func() { serverConn.Wait(); mapdel(filereq.Path) }()
return return
} }
@ -245,18 +270,29 @@ func runServer(host, sshport, httpport, keyfile string) {
} }
}() }()
http.ListenAndServe(fmt.Sprintf(":%s", httpport), http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { http.ListenAndServe(fmt.Sprintf(":%s", httpport), http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if strings.Contains(req.UserAgent(), "Slackbot-LinkExpanding") {
rw.WriteHeader(http.StatusForbidden)
fmt.Fprintln(rw, "No slackbots allowed!")
return
}
fp := strings.TrimPrefix(req.URL.Path, "/") fp := strings.TrimPrefix(req.URL.Path, "/")
if channel, ok := mapget(fp); ok { if fr, ok := mapget(fp); ok {
defer func() { (*channel).Close() }() channel, req, err := fr.serverconn.OpenChannel(fp, nil)
enc := gob.NewEncoder(*channel)
err := enc.Encode(&FileReq{fp})
if err != nil { if err != nil {
fmt.Fprintln(rw, err) fmt.Fprintln(os.Stderr, err)
return return
} }
io.Copy(rw, *channel) go ssh.DiscardRequests(req)
mapdel(fp)
_, err = io.Copy(rw, channel)
if err != nil {
fmt.Println(err)
}
if fr.ShareCount == 0 {
mapdel(fp)
channel.Close()
}
fr.ShareCount--
return return
} else { } else {
rw.WriteHeader(http.StatusNotFound) rw.WriteHeader(http.StatusNotFound)
@ -273,14 +309,15 @@ func buildCfg() *ssh.ServerConfig {
log.Fatalf("Failed to load authorized_keys, err: %v", err) log.Fatalf("Failed to load authorized_keys, err: %v", err)
} }
authorizedKeysMap := map[string]bool{} authorizedKeysMap := map[string]string{}
for len(authorizedKeysBytes) > 0 { for len(authorizedKeysBytes) > 0 {
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) pubKey, comment, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
fmt.Println(comment)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
authorizedKeysMap[string(pubKey.Marshal())] = true authorizedKeysMap[string(pubKey.Marshal())] = comment
authorizedKeysBytes = rest authorizedKeysBytes = rest
} }
@ -289,7 +326,7 @@ func buildCfg() *ssh.ServerConfig {
cfg.SetDefaults() cfg.SetDefaults()
cfg.PasswordCallback = func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return nil, fmt.Errorf("Public key only") } cfg.PasswordCallback = func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return nil, fmt.Errorf("Public key only") }
cfg.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { cfg.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if authorizedKeysMap[string(key.Marshal())] { if _, ok := authorizedKeysMap[string(key.Marshal())]; ok {
return nil, nil return nil, nil
} }
return nil, fmt.Errorf("unknown public key for %q", conn.User()) return nil, fmt.Errorf("unknown public key for %q", conn.User())