Support multiple downloads
This commit is contained in:
parent
03442ed63f
commit
2e8083e42f
2 changed files with 81 additions and 35 deletions
|
|
@ -1,2 +1,11 @@
|
||||||
# sharethis
|
# sharethis
|
||||||
Easy file sharing for the mere mortal
|
Easy file sharing for the mere mortal
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
|
||||||
|
- Better documentation. :P
|
||||||
|
- Logging (probably with log levels too)
|
||||||
|
- Per-account bandwidth quotas
|
||||||
|
- Admin functionality (use the client to add or remove keys, adjust quotas, etc)
|
||||||
|
- Remove hardcoded idiotic values
|
||||||
|
- UPnP hole-punching (so that the server does not have to proxy files (as in direct peer-to-peer file sending))
|
||||||
107
main.go
107
main.go
|
|
@ -22,12 +22,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type FileReq struct {
|
type FileReq struct {
|
||||||
Path string
|
Path string
|
||||||
}
|
ShareCount uint
|
||||||
|
serverconn *ssh.ServerConn
|
||||||
type chanpair struct {
|
|
||||||
sshchan *ssh.Channel
|
|
||||||
sigchan chan struct{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
@ -36,6 +33,7 @@ func main() {
|
||||||
remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact")
|
remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact")
|
||||||
sshport := flag.String("sshport", "2022", "the remote ssh port")
|
sshport := flag.String("sshport", "2022", "the remote ssh port")
|
||||||
httpport := flag.String("httpport", "8888", "the remote server's http port")
|
httpport := flag.String("httpport", "8888", "the remote server's http port")
|
||||||
|
sharecount := flag.Uint("count", 1, "Amount of times you want to share this file")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
if *server {
|
if *server {
|
||||||
runServer("0.0.0.0", *sshport, *httpport, "id_rsa")
|
runServer("0.0.0.0", *sshport, *httpport, "id_rsa")
|
||||||
|
|
@ -107,31 +105,55 @@ func main() {
|
||||||
// In the words of weezer, I've got my hashed path.
|
// In the words of weezer, I've got my hashed path.
|
||||||
// TODO: Get the remote URL from the remote server instead of rebuilding it locally.
|
// TODO: Get the remote URL from the remote server instead of rebuilding it locally.
|
||||||
// TODO: Clean up the port from the URL if it's 80 or 443
|
// TODO: Clean up the port from the URL if it's 80 or 443
|
||||||
fmt.Println(fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath))
|
var fileurl string
|
||||||
err = enc.Encode(&FileReq{hashedpath})
|
if *httpport == "443" {
|
||||||
|
fileurl = fmt.Sprintf("https://%s/%s", *remotehost, hashedpath)
|
||||||
|
} else if *httpport == "80" {
|
||||||
|
fileurl = fmt.Sprintf("http://%s/%s", *remotehost, hashedpath)
|
||||||
|
} else {
|
||||||
|
fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath)
|
||||||
|
}
|
||||||
|
fmt.Println(fileurl)
|
||||||
|
err = enc.Encode(&FileReq{Path: hashedpath, ShareCount: *sharecount})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
ch.Close()
|
ch.Close()
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
dec := gob.NewDecoder(ch)
|
|
||||||
var fr *FileReq
|
mu := &sync.Mutex{}
|
||||||
|
|
||||||
|
defer ch.Close()
|
||||||
|
ncc := connection.HandleChannelOpen(hashedpath)
|
||||||
for {
|
for {
|
||||||
err := dec.Decode(&fr)
|
nc := <-ncc
|
||||||
|
ch, req, err := nc.Accept()
|
||||||
|
go ssh.DiscardRequests(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Fprintln(os.Stderr, err)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
if fr.Path == hashedpath {
|
|
||||||
defer ch.Close()
|
go func() {
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintln(ch, "Sharethis error")
|
fmt.Fprintln(ch, "Sharethis error")
|
||||||
os.Exit(1)
|
return
|
||||||
}
|
}
|
||||||
io.Copy(ch, f)
|
_, err = io.Copy(ch, f)
|
||||||
return
|
if err != nil {
|
||||||
}
|
fmt.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
mu.Lock()
|
||||||
|
if *sharecount == 0 {
|
||||||
|
mu.Unlock()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
*sharecount--
|
||||||
|
mu.Unlock()
|
||||||
|
ch.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -159,18 +181,18 @@ func SSHAgent() ssh.AuthMethod {
|
||||||
}
|
}
|
||||||
|
|
||||||
func runServer(host, sshport, httpport, keyfile string) {
|
func runServer(host, sshport, httpport, keyfile string) {
|
||||||
filemap := map[string]*ssh.Channel{}
|
filemap := map[string]*FileReq{}
|
||||||
syncy := &sync.RWMutex{}
|
syncy := &sync.RWMutex{}
|
||||||
|
|
||||||
mapget := func(key string) (*ssh.Channel, bool) {
|
mapget := func(key string) (*FileReq, bool) {
|
||||||
syncy.RLock()
|
syncy.RLock()
|
||||||
defer syncy.RUnlock()
|
defer syncy.RUnlock()
|
||||||
c, ok := filemap[key]
|
c, ok := filemap[key]
|
||||||
return c, ok
|
return c, ok
|
||||||
}
|
}
|
||||||
mapset := func(key string, channel *ssh.Channel) {
|
mapset := func(key string, fr *FileReq) {
|
||||||
syncy.Lock()
|
syncy.Lock()
|
||||||
filemap[key] = channel
|
filemap[key] = fr
|
||||||
syncy.Unlock()
|
syncy.Unlock()
|
||||||
}
|
}
|
||||||
mapdel := func(key string) {
|
mapdel := func(key string) {
|
||||||
|
|
@ -235,7 +257,10 @@ func runServer(host, sshport, httpport, keyfile string) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
mapset(filereq.Path, &channel)
|
|
||||||
|
// TODO: also take the sharecount in the map.
|
||||||
|
filereq.serverconn = serverConn
|
||||||
|
mapset(filereq.Path, &filereq)
|
||||||
go func() { serverConn.Wait(); mapdel(filereq.Path) }()
|
go func() { serverConn.Wait(); mapdel(filereq.Path) }()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
@ -245,18 +270,29 @@ func runServer(host, sshport, httpport, keyfile string) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
http.ListenAndServe(fmt.Sprintf(":%s", httpport), http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
http.ListenAndServe(fmt.Sprintf(":%s", httpport), http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||||
|
if strings.Contains(req.UserAgent(), "Slackbot-LinkExpanding") {
|
||||||
|
rw.WriteHeader(http.StatusForbidden)
|
||||||
|
fmt.Fprintln(rw, "No slackbots allowed!")
|
||||||
|
return
|
||||||
|
}
|
||||||
fp := strings.TrimPrefix(req.URL.Path, "/")
|
fp := strings.TrimPrefix(req.URL.Path, "/")
|
||||||
if channel, ok := mapget(fp); ok {
|
if fr, ok := mapget(fp); ok {
|
||||||
defer func() { (*channel).Close() }()
|
channel, req, err := fr.serverconn.OpenChannel(fp, nil)
|
||||||
enc := gob.NewEncoder(*channel)
|
|
||||||
err := enc.Encode(&FileReq{fp})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintln(rw, err)
|
fmt.Fprintln(os.Stderr, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
io.Copy(rw, *channel)
|
go ssh.DiscardRequests(req)
|
||||||
mapdel(fp)
|
|
||||||
|
|
||||||
|
_, err = io.Copy(rw, channel)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
}
|
||||||
|
if fr.ShareCount == 0 {
|
||||||
|
mapdel(fp)
|
||||||
|
channel.Close()
|
||||||
|
}
|
||||||
|
fr.ShareCount--
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
rw.WriteHeader(http.StatusNotFound)
|
rw.WriteHeader(http.StatusNotFound)
|
||||||
|
|
@ -273,14 +309,15 @@ func buildCfg() *ssh.ServerConfig {
|
||||||
log.Fatalf("Failed to load authorized_keys, err: %v", err)
|
log.Fatalf("Failed to load authorized_keys, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
authorizedKeysMap := map[string]bool{}
|
authorizedKeysMap := map[string]string{}
|
||||||
for len(authorizedKeysBytes) > 0 {
|
for len(authorizedKeysBytes) > 0 {
|
||||||
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
pubKey, comment, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
||||||
|
fmt.Println(comment)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
authorizedKeysMap[string(pubKey.Marshal())] = true
|
authorizedKeysMap[string(pubKey.Marshal())] = comment
|
||||||
authorizedKeysBytes = rest
|
authorizedKeysBytes = rest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -289,7 +326,7 @@ func buildCfg() *ssh.ServerConfig {
|
||||||
cfg.SetDefaults()
|
cfg.SetDefaults()
|
||||||
cfg.PasswordCallback = func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return nil, fmt.Errorf("Public key only") }
|
cfg.PasswordCallback = func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return nil, fmt.Errorf("Public key only") }
|
||||||
cfg.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
cfg.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||||
if authorizedKeysMap[string(key.Marshal())] {
|
if _, ok := authorizedKeysMap[string(key.Marshal())]; ok {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("unknown public key for %q", conn.User())
|
return nil, fmt.Errorf("unknown public key for %q", conn.User())
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue