Support multiple downloads
This commit is contained in:
parent
03442ed63f
commit
2e8083e42f
2 changed files with 81 additions and 35 deletions
107
main.go
107
main.go
|
|
@ -22,12 +22,9 @@ import (
|
|||
)
|
||||
|
||||
type FileReq struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
type chanpair struct {
|
||||
sshchan *ssh.Channel
|
||||
sigchan chan struct{}
|
||||
Path string
|
||||
ShareCount uint
|
||||
serverconn *ssh.ServerConn
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
|
@ -36,6 +33,7 @@ func main() {
|
|||
remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact")
|
||||
sshport := flag.String("sshport", "2022", "the remote ssh port")
|
||||
httpport := flag.String("httpport", "8888", "the remote server's http port")
|
||||
sharecount := flag.Uint("count", 1, "Amount of times you want to share this file")
|
||||
flag.Parse()
|
||||
if *server {
|
||||
runServer("0.0.0.0", *sshport, *httpport, "id_rsa")
|
||||
|
|
@ -107,31 +105,55 @@ func main() {
|
|||
// In the words of weezer, I've got my hashed path.
|
||||
// TODO: Get the remote URL from the remote server instead of rebuilding it locally.
|
||||
// TODO: Clean up the port from the URL if it's 80 or 443
|
||||
fmt.Println(fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath))
|
||||
err = enc.Encode(&FileReq{hashedpath})
|
||||
var fileurl string
|
||||
if *httpport == "443" {
|
||||
fileurl = fmt.Sprintf("https://%s/%s", *remotehost, hashedpath)
|
||||
} else if *httpport == "80" {
|
||||
fileurl = fmt.Sprintf("http://%s/%s", *remotehost, hashedpath)
|
||||
} else {
|
||||
fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, hashedpath)
|
||||
}
|
||||
fmt.Println(fileurl)
|
||||
err = enc.Encode(&FileReq{Path: hashedpath, ShareCount: *sharecount})
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
ch.Close()
|
||||
os.Exit(1)
|
||||
}
|
||||
dec := gob.NewDecoder(ch)
|
||||
var fr *FileReq
|
||||
|
||||
mu := &sync.Mutex{}
|
||||
|
||||
defer ch.Close()
|
||||
ncc := connection.HandleChannelOpen(hashedpath)
|
||||
for {
|
||||
err := dec.Decode(&fr)
|
||||
nc := <-ncc
|
||||
ch, req, err := nc.Accept()
|
||||
go ssh.DiscardRequests(req)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
continue
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
}
|
||||
if fr.Path == hashedpath {
|
||||
defer ch.Close()
|
||||
|
||||
go func() {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
fmt.Fprintln(ch, "Sharethis error")
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
io.Copy(ch, f)
|
||||
return
|
||||
}
|
||||
_, err = io.Copy(ch, f)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
mu.Lock()
|
||||
if *sharecount == 0 {
|
||||
mu.Unlock()
|
||||
os.Exit(0)
|
||||
}
|
||||
*sharecount--
|
||||
mu.Unlock()
|
||||
ch.Close()
|
||||
}()
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -159,18 +181,18 @@ func SSHAgent() ssh.AuthMethod {
|
|||
}
|
||||
|
||||
func runServer(host, sshport, httpport, keyfile string) {
|
||||
filemap := map[string]*ssh.Channel{}
|
||||
filemap := map[string]*FileReq{}
|
||||
syncy := &sync.RWMutex{}
|
||||
|
||||
mapget := func(key string) (*ssh.Channel, bool) {
|
||||
mapget := func(key string) (*FileReq, bool) {
|
||||
syncy.RLock()
|
||||
defer syncy.RUnlock()
|
||||
c, ok := filemap[key]
|
||||
return c, ok
|
||||
}
|
||||
mapset := func(key string, channel *ssh.Channel) {
|
||||
mapset := func(key string, fr *FileReq) {
|
||||
syncy.Lock()
|
||||
filemap[key] = channel
|
||||
filemap[key] = fr
|
||||
syncy.Unlock()
|
||||
}
|
||||
mapdel := func(key string) {
|
||||
|
|
@ -235,7 +257,10 @@ func runServer(host, sshport, httpport, keyfile string) {
|
|||
if err != nil {
|
||||
continue
|
||||
}
|
||||
mapset(filereq.Path, &channel)
|
||||
|
||||
// TODO: also take the sharecount in the map.
|
||||
filereq.serverconn = serverConn
|
||||
mapset(filereq.Path, &filereq)
|
||||
go func() { serverConn.Wait(); mapdel(filereq.Path) }()
|
||||
return
|
||||
}
|
||||
|
|
@ -245,18 +270,29 @@ func runServer(host, sshport, httpport, keyfile string) {
|
|||
}
|
||||
}()
|
||||
http.ListenAndServe(fmt.Sprintf(":%s", httpport), http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
if strings.Contains(req.UserAgent(), "Slackbot-LinkExpanding") {
|
||||
rw.WriteHeader(http.StatusForbidden)
|
||||
fmt.Fprintln(rw, "No slackbots allowed!")
|
||||
return
|
||||
}
|
||||
fp := strings.TrimPrefix(req.URL.Path, "/")
|
||||
if channel, ok := mapget(fp); ok {
|
||||
defer func() { (*channel).Close() }()
|
||||
enc := gob.NewEncoder(*channel)
|
||||
err := enc.Encode(&FileReq{fp})
|
||||
if fr, ok := mapget(fp); ok {
|
||||
channel, req, err := fr.serverconn.OpenChannel(fp, nil)
|
||||
if err != nil {
|
||||
fmt.Fprintln(rw, err)
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
return
|
||||
}
|
||||
io.Copy(rw, *channel)
|
||||
mapdel(fp)
|
||||
go ssh.DiscardRequests(req)
|
||||
|
||||
_, err = io.Copy(rw, channel)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
if fr.ShareCount == 0 {
|
||||
mapdel(fp)
|
||||
channel.Close()
|
||||
}
|
||||
fr.ShareCount--
|
||||
return
|
||||
} else {
|
||||
rw.WriteHeader(http.StatusNotFound)
|
||||
|
|
@ -273,14 +309,15 @@ func buildCfg() *ssh.ServerConfig {
|
|||
log.Fatalf("Failed to load authorized_keys, err: %v", err)
|
||||
}
|
||||
|
||||
authorizedKeysMap := map[string]bool{}
|
||||
authorizedKeysMap := map[string]string{}
|
||||
for len(authorizedKeysBytes) > 0 {
|
||||
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
||||
pubKey, comment, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
|
||||
fmt.Println(comment)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
authorizedKeysMap[string(pubKey.Marshal())] = true
|
||||
authorizedKeysMap[string(pubKey.Marshal())] = comment
|
||||
authorizedKeysBytes = rest
|
||||
}
|
||||
|
||||
|
|
@ -289,7 +326,7 @@ func buildCfg() *ssh.ServerConfig {
|
|||
cfg.SetDefaults()
|
||||
cfg.PasswordCallback = func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return nil, fmt.Errorf("Public key only") }
|
||||
cfg.PublicKeyCallback = func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
|
||||
if authorizedKeysMap[string(key.Marshal())] {
|
||||
if _, ok := authorizedKeysMap[string(key.Marshal())]; ok {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unknown public key for %q", conn.User())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue