Support multiple downloads

This commit is contained in:
Olivier Tremblay 2018-10-15 07:29:55 -04:00
parent 03442ed63f
commit 2e8083e42f
No known key found for this signature in database
GPG key ID: D1C73ACB855E3A6D
2 changed files with 81 additions and 35 deletions

View file

@ -1,2 +1,11 @@
# sharethis
Easy file sharing for the mere mortal
# TODO
- Better documentation. :P
- Logging (probably with log levels too)
- Per-account bandwidth quotas
- Admin functionality (use the client to add or remove keys, adjust quotas, etc)
- Remove hardcoded idiotic values
- UPnP hole-punching (so that the server does not have to proxy files (as in direct peer-to-peer file sending))

107
main.go
View file

@ -22,12 +22,9 @@ import (
)
type FileReq struct {
Path string
}
type chanpair struct {
sshchan *ssh.Channel
sigchan chan struct{}
Path string
ShareCount uint
serverconn *ssh.ServerConn
}
func main() {
@ -36,6 +33,7 @@ func main() {
remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact")
sshport := flag.String("sshport", "2022", "the remote ssh port")
httpport := flag.String("httpport", "8888", "the remote server's http port")
sharecount := flag.Uint("count", 1, "Amount of times you want to share this file")
flag.Parse()
if *server {
runServer("0.0.0.0", *sshport, *httpport, "id_rsa")
@ -107,31 +105,55 @@ func main() {
// In the words of weezer, I've got my hashed path.
// TODO: Get the remote URL from the remote server instead of rebuilding it locally.
// TODO: Clean up the port from the URL if it's 80 or 443
fmt.Println(fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath))
err = enc.Encode(&FileReq{hashedpath})
var fileurl string
if *httpport == "443" {
fileurl = fmt.Sprintf("https://%s/%s", *remotehost, hashedpath)
} else if *httpport == "80" {
fileurl = fmt.Sprintf("http://%s/%s", *remotehost, hashedpath)
} else {
fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath)
}
fmt.Println(fileurl)
err = enc.Encode(&FileReq{Path: hashedpath, ShareCount: *sharecount})
if err != nil {
fmt.Println(err)
ch.Close()
os.Exit(1)
}
dec := gob.NewDecoder(ch)
var fr *FileReq
mu := &sync.Mutex{}
defer ch.Close()
ncc := connection.HandleChannelOpen(hashedpath)
for {
err := dec.Decode(&fr)
nc := <-ncc
ch, req, err := nc.Accept()
go ssh.DiscardRequests(req)
if err != nil {
fmt.Println(err)
continue
fmt.Fprintln(os.Stderr, err)
}
if fr.Path == hashedpath {
defer ch.Close()
go func() {
f, err := os.Open(path)
if err != nil {
fmt.Fprintln(ch, "Sharethis error")
os.Exit(1)
return
}
io.Copy(ch, f)
return
}
_, err = io.Copy(ch, f)
if err != nil {
fmt.Println(err)
return
}
mu.Lock()
if *sharecount == 0 {
mu.Unlock()
os.Exit(0)
}
*sharecount--
mu.Unlock()
ch.Close()
}()
}
}
@ -159,18 +181,18 @@ func SSHAgent() ssh.AuthMethod {
}
func runServer(host, sshport, httpport, keyfile string) {
filemap := map[string]*ssh.Channel{}
filemap := map[string]*FileReq{}
syncy := &sync.RWMutex{}
mapget := func(key string) (*ssh.Channel, bool) {
mapget := func(key string) (*FileReq, bool) {
syncy.RLock()
defer syncy.RUnlock()
c, ok := filemap[key]
return c, ok
}
mapset := func(key string, channel *ssh.Channel) {
mapset := func(key string, fr *FileReq) {
syncy.Lock()
filemap[key] = channel
filemap[key] = fr
syncy.Unlock()
}
mapdel := func(key string) {
@ -235,7 +257,10 @@ func runServer(host, sshport, httpport, keyfile string) {
if err != nil {
continue
}
mapset(filereq.Path, &channel)
// TODO: also take the sharecount in the map.
filereq.serverconn = serverConn
mapset(filereq.Path, &filereq)
go func() { serverConn.Wait(); mapdel(filereq.Path) }()
return
}
@ -245,18 +270,29 @@ func runServer(host, sshport, httpport, keyfile string) {
}
}()
http.ListenAndServe(fmt.Sprintf(":%s", httpport), http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if strings.Contains(req.UserAgent(), "Slackbot-LinkExpanding") {
rw.WriteHeader(http.StatusForbidden)
fmt.Fprintln(rw, "No slackbots allowed!")
return
}
fp := strings.TrimPrefix(req.URL.Path, "/")
if channel, ok := mapget(fp); ok {
defer func() { (*channel).Close() }()
enc := gob.NewEncoder(*channel)
err := enc.Encode(&FileReq{fp})
if fr, ok := mapget(fp); ok {
channel, req, err := fr.serverconn.OpenChannel(fp, nil)
if err != nil {
fmt.Fprintln(rw, err)
fmt.Fprintln(os.Stderr, err)
return
}
io.Copy(rw, *channel)
mapdel(fp)
go ssh.DiscardRequests(req)
_, err = io.Copy(rw, channel)
if err != nil {
fmt.Println(err)
}
if fr.ShareCount == 0 {
mapdel(fp)
channel.Close()
}
fr.ShareCount--
return
} else {
rw.WriteHeader(http.StatusNotFound)
@ -273,14 +309,15 @@ func buildCfg() *ssh.ServerConfig {
log.Fatalf("Failed to load authorized_keys, err: %v", err)
}
authorizedKeysMap := map[string]bool{}
authorizedKeysMap := map[string]string{}
for len(authorizedKeysBytes) > 0 {
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
pubKey, comment, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
fmt.Println(comment)
if err != nil {
log.Fatal(err)
}
authorizedKeysMap[string(pubKey.Marshal())] = true
authorizedKeysMap[string(pubKey.Marshal())] = comment
authorizedKeysBytes = rest
}
@ -289,7 +326,7 @@ func buildCfg() *ssh.ServerConfig {
cfg.SetDefaults()
cfg.PasswordCallback = func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return nil, fmt.Errorf("Public key only") }
cfg.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if authorizedKeysMap[string(key.Marshal())] {
if _, ok := authorizedKeysMap[string(key.Marshal())]; ok {
return nil, nil
}
return nil, fmt.Errorf("unknown public key for %q", conn.User())