diff --git a/README.md b/README.md index cbd9193..5f0197c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,11 @@ # sharethis Easy file sharing for the mere mortal + +# TODO + +- Better documentation. :P +- Logging (probably with log levels too) +- Per-account bandwidth quotas +- Admin functionality (use the client to add or remove keys, adjust quotas, etc) +- Remove hardcoded idiotic values +- UPnP hole-punching (so that the server does not have to proxy files (as in direct peer-to-peer file sending)) \ No newline at end of file diff --git a/main.go b/main.go index 1d635bc..bb07545 100644 --- a/main.go +++ b/main.go @@ -22,12 +22,9 @@ import ( ) type FileReq struct { - Path string -} - -type chanpair struct { - sshchan *ssh.Channel - sigchan chan struct{} + Path string + ShareCount uint + serverconn *ssh.ServerConn } func main() { @@ -36,6 +33,7 @@ func main() { remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact") sshport := flag.String("sshport", "2022", "the remote ssh port") httpport := flag.String("httpport", "8888", "the remote server's http port") + sharecount := flag.Uint("count", 1, "Amount of times you want to share this file") flag.Parse() if *server { runServer("0.0.0.0", *sshport, *httpport, "id_rsa") @@ -107,31 +105,55 @@ func main() { // In the words of weezer, I've got my hashed path. // TODO: Get the remote URL from the remote server instead of rebuilding it locally. // TODO: Clean up the port from the URL if it's 80 or 443 - fmt.Println(fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath)) - err = enc.Encode(&FileReq{hashedpath}) + var fileurl string + if *httpport == "443" { + fileurl = fmt.Sprintf("https://%s/%s", *remotehost, hashedpath) + } else if *httpport == "80" { + fileurl = fmt.Sprintf("http://%s/%s", *remotehost, hashedpath) + } else { + fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath) + } + fmt.Println(fileurl) + err = enc.Encode(&FileReq{Path: hashedpath, ShareCount: *sharecount}) if err != nil { fmt.Println(err) ch.Close() os.Exit(1) } - dec := gob.NewDecoder(ch) - var fr *FileReq + + mu := &sync.Mutex{} + + defer ch.Close() + ncc := connection.HandleChannelOpen(hashedpath) for { - err := dec.Decode(&fr) + nc := <-ncc + ch, req, err := nc.Accept() + go ssh.DiscardRequests(req) if err != nil { - fmt.Println(err) - continue + fmt.Fprintln(os.Stderr, err) } - if fr.Path == hashedpath { - defer ch.Close() + + go func() { f, err := os.Open(path) if err != nil { fmt.Fprintln(ch, "Sharethis error") - os.Exit(1) + return } - io.Copy(ch, f) - return - } + _, err = io.Copy(ch, f) + if err != nil { + fmt.Println(err) + return + } + mu.Lock() + if *sharecount == 0 { + mu.Unlock() + os.Exit(0) + } + *sharecount-- + mu.Unlock() + ch.Close() + }() + } } @@ -159,18 +181,18 @@ func SSHAgent() ssh.AuthMethod { } func runServer(host, sshport, httpport, keyfile string) { - filemap := map[string]*ssh.Channel{} + filemap := map[string]*FileReq{} syncy := &sync.RWMutex{} - mapget := func(key string) (*ssh.Channel, bool) { + mapget := func(key string) (*FileReq, bool) { syncy.RLock() defer syncy.RUnlock() c, ok := filemap[key] return c, ok } - mapset := func(key string, channel *ssh.Channel) { + mapset := func(key string, fr *FileReq) { syncy.Lock() - filemap[key] = channel + filemap[key] = fr syncy.Unlock() } mapdel := func(key string) { @@ -235,7 +257,10 @@ func runServer(host, sshport, httpport, keyfile string) { if err != nil { continue } - mapset(filereq.Path, &channel) + + // TODO: also take the sharecount in the map. + filereq.serverconn = serverConn + mapset(filereq.Path, &filereq) go func() { serverConn.Wait(); mapdel(filereq.Path) }() return } @@ -245,18 +270,29 @@ func runServer(host, sshport, httpport, keyfile string) { } }() http.ListenAndServe(fmt.Sprintf(":%s", httpport), http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if strings.Contains(req.UserAgent(), "Slackbot-LinkExpanding") { + rw.WriteHeader(http.StatusForbidden) + fmt.Fprintln(rw, "No slackbots allowed!") + return + } fp := strings.TrimPrefix(req.URL.Path, "/") - if channel, ok := mapget(fp); ok { - defer func() { (*channel).Close() }() - enc := gob.NewEncoder(*channel) - err := enc.Encode(&FileReq{fp}) + if fr, ok := mapget(fp); ok { + channel, req, err := fr.serverconn.OpenChannel(fp, nil) if err != nil { - fmt.Fprintln(rw, err) + fmt.Fprintln(os.Stderr, err) return } - io.Copy(rw, *channel) - mapdel(fp) + go ssh.DiscardRequests(req) + _, err = io.Copy(rw, channel) + if err != nil { + fmt.Println(err) + } + if fr.ShareCount == 0 { + mapdel(fp) + channel.Close() + } + fr.ShareCount-- return } else { rw.WriteHeader(http.StatusNotFound) @@ -273,14 +309,15 @@ func buildCfg() *ssh.ServerConfig { log.Fatalf("Failed to load authorized_keys, err: %v", err) } - authorizedKeysMap := map[string]bool{} + authorizedKeysMap := map[string]string{} for len(authorizedKeysBytes) > 0 { - pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + pubKey, comment, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + fmt.Println(comment) if err != nil { log.Fatal(err) } - authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysMap[string(pubKey.Marshal())] = comment authorizedKeysBytes = rest } @@ -289,7 +326,7 @@ func buildCfg() *ssh.ServerConfig { cfg.SetDefaults() cfg.PasswordCallback = func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return nil, fmt.Errorf("Public key only") } cfg.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { - if authorizedKeysMap[string(key.Marshal())] { + if _, ok := authorizedKeysMap[string(key.Marshal())]; ok { return nil, nil } return nil, fmt.Errorf("unknown public key for %q", conn.User())