Extract some stuff out
This commit is contained in:
parent
8675e6815a
commit
46286cb851
3 changed files with 199 additions and 114 deletions
8
channel.go
Normal file
8
channel.go
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
package sharethis
|
||||
|
||||
import "golang.org/x/crypto/ssh"
|
||||
|
||||
type NewChannel struct {
|
||||
ssh.NewChannel
|
||||
DoZip bool
|
||||
}
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
package main // import "github.com/otremblay/sharethis"
|
||||
package main // import "github.com/otremblay/sharethis/cmd/sharethis"
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"crypto/sha256"
|
||||
"archive/zip"
|
||||
"encoding/gob"
|
||||
"flag"
|
||||
"fmt"
|
||||
|
|
@ -13,26 +13,19 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/otremblay/sharethis"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/agent"
|
||||
)
|
||||
|
||||
type FileReq struct {
|
||||
Path string
|
||||
ShareCount uint
|
||||
serverconn *ssh.ServerConn
|
||||
}
|
||||
|
||||
var authorizedKeys = os.Getenv("SHARETHIS_AUTHORIZEDKEYS")
|
||||
|
||||
func main() {
|
||||
bg := flag.Bool("bg", false, "sends the process in the background")
|
||||
server := flag.Bool("server", false, "makes the process an http server")
|
||||
// admin := flag.String("admin", "", "determine admin user")
|
||||
remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact")
|
||||
sshport := flag.String("sshport", "2022", "the remote ssh port")
|
||||
httpport := flag.String("httpport", "8888", "the remote server's http port")
|
||||
|
|
@ -63,28 +56,8 @@ func main() {
|
|||
if len(flag.Args()) < 1 {
|
||||
log.Fatalln("Need filename")
|
||||
}
|
||||
var path string
|
||||
var dir bool
|
||||
var name string
|
||||
if fs, err := os.Stat(flag.Arg(0)); err != nil {
|
||||
log.Fatalln("Can't read file")
|
||||
} else {
|
||||
name = fs.Name()
|
||||
p, err := filepath.Abs(fs.Name())
|
||||
if err != nil {
|
||||
log.Fatalln("Can't read file")
|
||||
}
|
||||
path = p
|
||||
dir = fs.IsDir()
|
||||
}
|
||||
path := flag.Arg(0)
|
||||
|
||||
if *bg {
|
||||
_, err := syscall.ForkExec(os.Args[0], append([]string{os.Args[0]}, flag.Args()...), &syscall.ProcAttr{Files: []uintptr{0, 1, 2}})
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
var username string
|
||||
userobj, err := user.Current()
|
||||
if err != nil {
|
||||
|
|
@ -124,36 +97,10 @@ func main() {
|
|||
log.Fatalln("poop", err)
|
||||
}
|
||||
enc := gob.NewEncoder(ch)
|
||||
path = flag.Arg(0)
|
||||
fr := sharethis.NewFileReq(path, *httpport, *remotehost, *sharecount)
|
||||
fmt.Println(fr)
|
||||
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Could not get hostname with os.Hostname()")
|
||||
hostname = "unknown"
|
||||
}
|
||||
fullpath := fmt.Sprintf("%s@%s:%s", username, hostname, path)
|
||||
hashedpath := fmt.Sprintf("%x", sha256.Sum256([]byte(fullpath)))
|
||||
|
||||
// In the words of weezer, I've got my hashed path.
|
||||
// TODO: Get the remote URL from the remote server instead of rebuilding it locally.
|
||||
var fileurl string
|
||||
fullpath = fmt.Sprintf("%s/%s", hashedpath, name)
|
||||
if *httpport == "443" {
|
||||
fileurl = fmt.Sprintf("https://%s/%s", *remotehost, fullpath)
|
||||
} else if *httpport == "80" {
|
||||
fileurl = fmt.Sprintf("http://%s/%s", *remotehost, fullpath)
|
||||
} else {
|
||||
fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, fullpath)
|
||||
}
|
||||
|
||||
if dir {
|
||||
fmt.Println(fileurl + ".tar")
|
||||
fmt.Println(fileurl + ".zip")
|
||||
} else {
|
||||
fmt.Println(fileurl)
|
||||
}
|
||||
|
||||
err = enc.Encode(&FileReq{Path: fullpath, ShareCount: *sharecount})
|
||||
err = enc.Encode(fr)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
ch.Close()
|
||||
|
|
@ -163,23 +110,23 @@ func main() {
|
|||
mu := &sync.Mutex{}
|
||||
|
||||
defer ch.Close()
|
||||
var getchfn func() ssh.NewChannel
|
||||
if dir {
|
||||
ncc := connection.HandleChannelOpen(fullpath + ".tar")
|
||||
ncc2 := connection.HandleChannelOpen(fullpath + ".zip")
|
||||
getchfn = func() ssh.NewChannel {
|
||||
var getchfn func() *sharethis.NewChannel
|
||||
if fr.IsDir {
|
||||
ncc := connection.HandleChannelOpen(fr.Path + ".tar")
|
||||
ncc2 := connection.HandleChannelOpen(fr.Path + ".zip")
|
||||
getchfn = func() *sharethis.NewChannel {
|
||||
select {
|
||||
case nc := <-ncc:
|
||||
return nc
|
||||
return &sharethis.NewChannel{NewChannel: nc}
|
||||
case nc := <-ncc2:
|
||||
return nc
|
||||
return &sharethis.NewChannel{NewChannel: nc, DoZip: true}
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
ncc := connection.HandleChannelOpen(fullpath)
|
||||
getchfn = func() ssh.NewChannel {
|
||||
return <-ncc
|
||||
ncc := connection.HandleChannelOpen(fr.Path)
|
||||
getchfn = func() *sharethis.NewChannel {
|
||||
return &sharethis.NewChannel{<-ncc, false}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -191,7 +138,13 @@ func main() {
|
|||
fmt.Fprintln(os.Stderr, err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
go handleFileRequest(path, ch, mu, sharecount, nc.DoZip)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func handleFileRequest(path string, ch ssh.Channel, mu *sync.Mutex, sharecount *uint, dozip bool) {
|
||||
|
||||
f, err := os.Open(path)
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -199,8 +152,54 @@ func main() {
|
|||
return
|
||||
}
|
||||
if fs, _ := f.Stat(); fs.IsDir() {
|
||||
if dozip {
|
||||
w := zip.NewWriter(ch)
|
||||
zipfn := func(fi os.FileInfo, path string, file *os.File) error {
|
||||
hdr, err := zip.FileInfoHeader(fi)
|
||||
hdr.Name = path
|
||||
if err != nil {
|
||||
return fmt.Errorf("Couldn't build zip header: %v", err)
|
||||
}
|
||||
f, err := w.CreateHeader(hdr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Couldn't write zip header: %v", err)
|
||||
}
|
||||
_, err = io.Copy(f, file)
|
||||
return err
|
||||
}
|
||||
err := WriteFiles(zipfn, fs, ".")
|
||||
if err != nil {
|
||||
fmt.Fprintln(ch, "Error building tar", err)
|
||||
}
|
||||
w.Close()
|
||||
mu.Lock()
|
||||
if *sharecount == 0 {
|
||||
mu.Unlock()
|
||||
os.Exit(0)
|
||||
}
|
||||
*sharecount--
|
||||
mu.Unlock()
|
||||
ch.Close()
|
||||
|
||||
return
|
||||
|
||||
} else {
|
||||
tw := tar.NewWriter(ch)
|
||||
err := WriteFiles(tw, fs, ".")
|
||||
tarfn := func(fi os.FileInfo, path string, file *os.File) error {
|
||||
hdr, err := tar.FileInfoHeader(fi, "")
|
||||
hdr.Name = path
|
||||
if err != nil {
|
||||
return fmt.Errorf("Couldn't build tar header: %v", err)
|
||||
}
|
||||
|
||||
err = tw.WriteHeader(hdr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Couldn't write tar header: %v", err)
|
||||
}
|
||||
_, err = io.Copy(tw, file)
|
||||
return err
|
||||
}
|
||||
err := WriteFiles(tarfn, fs, ".")
|
||||
if err != nil {
|
||||
fmt.Fprintln(ch, "Error building tar", err)
|
||||
}
|
||||
|
|
@ -216,6 +215,7 @@ func main() {
|
|||
|
||||
return
|
||||
}
|
||||
}
|
||||
_, err = io.Copy(ch, f)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
|
|
@ -229,12 +229,10 @@ func main() {
|
|||
*sharecount--
|
||||
mu.Unlock()
|
||||
ch.Close()
|
||||
}()
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func WriteFiles(tw *tar.Writer, fi os.FileInfo, path string) error {
|
||||
func WriteFiles(writerfn func(os.FileInfo, string, *os.File) error, fi os.FileInfo, path string) error {
|
||||
p := path + "/" + fi.Name()
|
||||
f, err := os.Open(p)
|
||||
if err != nil {
|
||||
|
|
@ -247,25 +245,15 @@ func WriteFiles(tw *tar.Writer, fi os.FileInfo, path string) error {
|
|||
}
|
||||
var cumulerror error
|
||||
for _, nfi := range fis {
|
||||
err := WriteFiles(tw, nfi, p)
|
||||
err := WriteFiles(writerfn, nfi, p)
|
||||
if err != nil {
|
||||
cumulerror = fmt.Errorf("%vCouldn't write files for %s: %v\n", cumulerror, nfi.Name(), err)
|
||||
}
|
||||
}
|
||||
return cumulerror
|
||||
}
|
||||
hdr, err := tar.FileInfoHeader(fi, "")
|
||||
hdr.Name = p
|
||||
if err != nil {
|
||||
return fmt.Errorf("Couldn't build tar header: %v", err)
|
||||
}
|
||||
|
||||
err = tw.WriteHeader(hdr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Couldn't write tar header: %v", err)
|
||||
}
|
||||
_, err = io.Copy(tw, f)
|
||||
return err
|
||||
return writerfn(fi, p, f)
|
||||
}
|
||||
|
||||
func PublicKeyFile(file string) (ssh.AuthMethod, error) {
|
||||
|
|
@ -297,16 +285,16 @@ func SSHAgent() (ssh.AuthMethod, error) {
|
|||
}
|
||||
|
||||
func runServer(host, sshport, httpport, keyfile string) {
|
||||
filemap := map[string]*FileReq{}
|
||||
filemap := map[string]*sharethis.FileReq{}
|
||||
syncy := &sync.RWMutex{}
|
||||
|
||||
mapget := func(key string) (*FileReq, bool) {
|
||||
mapget := func(key string) (*sharethis.FileReq, bool) {
|
||||
syncy.RLock()
|
||||
defer syncy.RUnlock()
|
||||
c, ok := filemap[key]
|
||||
return c, ok
|
||||
}
|
||||
mapset := func(key string, fr *FileReq) {
|
||||
mapset := func(key string, fr *sharethis.FileReq) {
|
||||
syncy.Lock()
|
||||
filemap[key] = fr
|
||||
syncy.Unlock()
|
||||
|
|
@ -369,7 +357,7 @@ func runServer(host, sshport, httpport, keyfile string) {
|
|||
|
||||
go func() {
|
||||
dec := gob.NewDecoder(channel)
|
||||
var filereq FileReq
|
||||
var filereq sharethis.FileReq
|
||||
|
||||
for {
|
||||
err := dec.Decode(&filereq)
|
||||
|
|
@ -378,7 +366,7 @@ func runServer(host, sshport, httpport, keyfile string) {
|
|||
}
|
||||
|
||||
// TODO: also take the sharecount in the map.
|
||||
filereq.serverconn = serverConn
|
||||
filereq.ServerConn = serverConn
|
||||
mapset(filereq.Path, &filereq)
|
||||
go func() { serverConn.Wait(); mapdel(filereq.Path) }()
|
||||
return
|
||||
|
|
@ -406,7 +394,7 @@ func runServer(host, sshport, httpport, keyfile string) {
|
|||
}
|
||||
fp = strings.TrimSuffix(fp, ".tar")
|
||||
if fr, ok := mapget(fp); ok {
|
||||
channel, req, err := fr.serverconn.OpenChannel(fp+suffix, nil)
|
||||
channel, req, err := fr.ServerConn.OpenChannel(fp+suffix, nil)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
return
|
||||
89
filereq.go
Normal file
89
filereq.go
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
package sharethis // import "github.com/otremblay/sharethis"
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type FileReq struct {
|
||||
Path string
|
||||
ShareCount uint
|
||||
Username string
|
||||
Hostname string
|
||||
ServerConn *ssh.ServerConn
|
||||
fileurl string
|
||||
httpport string
|
||||
remotehost string
|
||||
localpath string
|
||||
filename string
|
||||
IsDir bool
|
||||
}
|
||||
|
||||
func NewFileReq(path, httpport, remotehost string, shareCount uint) *FileReq {
|
||||
return getFileReq(DefaultPrefixFn, path, httpport, remotehost, shareCount)
|
||||
}
|
||||
|
||||
func (fr *FileReq) RebuildUrls() {
|
||||
fr = getFileReq(DefaultPrefixFn, fr.localpath, fr.httpport, fr.remotehost, fr.ShareCount)
|
||||
}
|
||||
|
||||
func (fr *FileReq) String() string {
|
||||
if fr.IsDir {
|
||||
return fmt.Sprintf("%s\n%s", fr.fileurl+".tar", fr.fileurl+".zip")
|
||||
} else {
|
||||
return fmt.Sprintf(fr.fileurl)
|
||||
}
|
||||
}
|
||||
|
||||
func getRemotePath(prefixFn func() string, path string) string {
|
||||
fullpath := fmt.Sprintf("%s:%s", prefixFn(), path)
|
||||
hashedpath := fmt.Sprintf("%x", sha256.Sum256([]byte(fullpath)))
|
||||
return fmt.Sprintf("%s/%s", hashedpath, filepath.Base(path))
|
||||
}
|
||||
|
||||
func getFileReq(prefixFn func() string, path string, httpport string, remotehost string, sharecount uint) *FileReq {
|
||||
var dir bool
|
||||
if fs, err := os.Stat(path); err != nil {
|
||||
log.Fatalln("Can't read file")
|
||||
} else {
|
||||
p, err := filepath.Abs(fs.Name())
|
||||
if err != nil {
|
||||
log.Fatalln("Can't read file")
|
||||
}
|
||||
dir = fs.IsDir()
|
||||
path = p
|
||||
}
|
||||
remotepath := getRemotePath(prefixFn, path)
|
||||
var fileurl string
|
||||
if httpport == "443" {
|
||||
fileurl = fmt.Sprintf("https://%s/%s", remotehost, remotepath)
|
||||
} else if httpport == "80" {
|
||||
fileurl = fmt.Sprintf("http://%s/%s", remotehost, remotepath)
|
||||
} else {
|
||||
fileurl = fmt.Sprintf("http://%s:%s/%s", remotehost, httpport, remotepath)
|
||||
}
|
||||
return &FileReq{Path: remotepath, ShareCount: sharecount, fileurl: fileurl, localpath: path, IsDir: dir}
|
||||
}
|
||||
|
||||
var DefaultPrefixFn = func() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Could not get hostname with os.Hostname()")
|
||||
hostname = "unknown"
|
||||
}
|
||||
var username string
|
||||
userobj, err := user.Current()
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "Could not get user with user.Current()")
|
||||
username = "unknown"
|
||||
} else {
|
||||
username = userobj.Username
|
||||
}
|
||||
return fmt.Sprintf("%s@%s", username, hostname)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue