Extract some stuff out
This commit is contained in:
parent
8675e6815a
commit
46286cb851
3 changed files with 199 additions and 114 deletions
8
channel.go
Normal file
8
channel.go
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
package sharethis
|
||||||
|
|
||||||
|
import "golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
type NewChannel struct {
|
||||||
|
ssh.NewChannel
|
||||||
|
DoZip bool
|
||||||
|
}
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
package main // import "github.com/otremblay/sharethis"
|
package main // import "github.com/otremblay/sharethis/cmd/sharethis"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"archive/tar"
|
"archive/tar"
|
||||||
"crypto/sha256"
|
"archive/zip"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
@ -13,26 +13,19 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/user"
|
"os/user"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
|
||||||
|
|
||||||
|
"github.com/otremblay/sharethis"
|
||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
"golang.org/x/crypto/ssh/agent"
|
"golang.org/x/crypto/ssh/agent"
|
||||||
)
|
)
|
||||||
|
|
||||||
type FileReq struct {
|
|
||||||
Path string
|
|
||||||
ShareCount uint
|
|
||||||
serverconn *ssh.ServerConn
|
|
||||||
}
|
|
||||||
|
|
||||||
var authorizedKeys = os.Getenv("SHARETHIS_AUTHORIZEDKEYS")
|
var authorizedKeys = os.Getenv("SHARETHIS_AUTHORIZEDKEYS")
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
bg := flag.Bool("bg", false, "sends the process in the background")
|
|
||||||
server := flag.Bool("server", false, "makes the process an http server")
|
server := flag.Bool("server", false, "makes the process an http server")
|
||||||
|
// admin := flag.String("admin", "", "determine admin user")
|
||||||
remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact")
|
remotehost := flag.String("remote", "share.otremblay.com", "remote server for sharethis to contact")
|
||||||
sshport := flag.String("sshport", "2022", "the remote ssh port")
|
sshport := flag.String("sshport", "2022", "the remote ssh port")
|
||||||
httpport := flag.String("httpport", "8888", "the remote server's http port")
|
httpport := flag.String("httpport", "8888", "the remote server's http port")
|
||||||
|
|
@ -63,28 +56,8 @@ func main() {
|
||||||
if len(flag.Args()) < 1 {
|
if len(flag.Args()) < 1 {
|
||||||
log.Fatalln("Need filename")
|
log.Fatalln("Need filename")
|
||||||
}
|
}
|
||||||
var path string
|
path := flag.Arg(0)
|
||||||
var dir bool
|
|
||||||
var name string
|
|
||||||
if fs, err := os.Stat(flag.Arg(0)); err != nil {
|
|
||||||
log.Fatalln("Can't read file")
|
|
||||||
} else {
|
|
||||||
name = fs.Name()
|
|
||||||
p, err := filepath.Abs(fs.Name())
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln("Can't read file")
|
|
||||||
}
|
|
||||||
path = p
|
|
||||||
dir = fs.IsDir()
|
|
||||||
}
|
|
||||||
|
|
||||||
if *bg {
|
|
||||||
_, err := syscall.ForkExec(os.Args[0], append([]string{os.Args[0]}, flag.Args()...), &syscall.ProcAttr{Files: []uintptr{0, 1, 2}})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln(err)
|
|
||||||
}
|
|
||||||
os.Exit(0)
|
|
||||||
}
|
|
||||||
var username string
|
var username string
|
||||||
userobj, err := user.Current()
|
userobj, err := user.Current()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -124,36 +97,10 @@ func main() {
|
||||||
log.Fatalln("poop", err)
|
log.Fatalln("poop", err)
|
||||||
}
|
}
|
||||||
enc := gob.NewEncoder(ch)
|
enc := gob.NewEncoder(ch)
|
||||||
path = flag.Arg(0)
|
fr := sharethis.NewFileReq(path, *httpport, *remotehost, *sharecount)
|
||||||
|
fmt.Println(fr)
|
||||||
|
|
||||||
hostname, err := os.Hostname()
|
err = enc.Encode(fr)
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintln(os.Stderr, "Could not get hostname with os.Hostname()")
|
|
||||||
hostname = "unknown"
|
|
||||||
}
|
|
||||||
fullpath := fmt.Sprintf("%s@%s:%s", username, hostname, path)
|
|
||||||
hashedpath := fmt.Sprintf("%x", sha256.Sum256([]byte(fullpath)))
|
|
||||||
|
|
||||||
// In the words of weezer, I've got my hashed path.
|
|
||||||
// TODO: Get the remote URL from the remote server instead of rebuilding it locally.
|
|
||||||
var fileurl string
|
|
||||||
fullpath = fmt.Sprintf("%s/%s", hashedpath, name)
|
|
||||||
if *httpport == "443" {
|
|
||||||
fileurl = fmt.Sprintf("https://%s/%s", *remotehost, fullpath)
|
|
||||||
} else if *httpport == "80" {
|
|
||||||
fileurl = fmt.Sprintf("http://%s/%s", *remotehost, fullpath)
|
|
||||||
} else {
|
|
||||||
fileurl = fmt.Sprintf("http://%s:%s/%s", *remotehost, *httpport, fullpath)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dir {
|
|
||||||
fmt.Println(fileurl + ".tar")
|
|
||||||
fmt.Println(fileurl + ".zip")
|
|
||||||
} else {
|
|
||||||
fmt.Println(fileurl)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = enc.Encode(&FileReq{Path: fullpath, ShareCount: *sharecount})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
ch.Close()
|
ch.Close()
|
||||||
|
|
@ -163,23 +110,23 @@ func main() {
|
||||||
mu := &sync.Mutex{}
|
mu := &sync.Mutex{}
|
||||||
|
|
||||||
defer ch.Close()
|
defer ch.Close()
|
||||||
var getchfn func() ssh.NewChannel
|
var getchfn func() *sharethis.NewChannel
|
||||||
if dir {
|
if fr.IsDir {
|
||||||
ncc := connection.HandleChannelOpen(fullpath + ".tar")
|
ncc := connection.HandleChannelOpen(fr.Path + ".tar")
|
||||||
ncc2 := connection.HandleChannelOpen(fullpath + ".zip")
|
ncc2 := connection.HandleChannelOpen(fr.Path + ".zip")
|
||||||
getchfn = func() ssh.NewChannel {
|
getchfn = func() *sharethis.NewChannel {
|
||||||
select {
|
select {
|
||||||
case nc := <-ncc:
|
case nc := <-ncc:
|
||||||
return nc
|
return &sharethis.NewChannel{NewChannel: nc}
|
||||||
case nc := <-ncc2:
|
case nc := <-ncc2:
|
||||||
return nc
|
return &sharethis.NewChannel{NewChannel: nc, DoZip: true}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
ncc := connection.HandleChannelOpen(fullpath)
|
ncc := connection.HandleChannelOpen(fr.Path)
|
||||||
getchfn = func() ssh.NewChannel {
|
getchfn = func() *sharethis.NewChannel {
|
||||||
return <-ncc
|
return &sharethis.NewChannel{<-ncc, false}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -191,7 +138,13 @@ func main() {
|
||||||
fmt.Fprintln(os.Stderr, err)
|
fmt.Fprintln(os.Stderr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go handleFileRequest(path, ch, mu, sharecount, nc.DoZip)
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleFileRequest(path string, ch ssh.Channel, mu *sync.Mutex, sharecount *uint, dozip bool) {
|
||||||
|
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -199,8 +152,54 @@ func main() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if fs, _ := f.Stat(); fs.IsDir() {
|
if fs, _ := f.Stat(); fs.IsDir() {
|
||||||
|
if dozip {
|
||||||
|
w := zip.NewWriter(ch)
|
||||||
|
zipfn := func(fi os.FileInfo, path string, file *os.File) error {
|
||||||
|
hdr, err := zip.FileInfoHeader(fi)
|
||||||
|
hdr.Name = path
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Couldn't build zip header: %v", err)
|
||||||
|
}
|
||||||
|
f, err := w.CreateHeader(hdr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Couldn't write zip header: %v", err)
|
||||||
|
}
|
||||||
|
_, err = io.Copy(f, file)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err := WriteFiles(zipfn, fs, ".")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(ch, "Error building tar", err)
|
||||||
|
}
|
||||||
|
w.Close()
|
||||||
|
mu.Lock()
|
||||||
|
if *sharecount == 0 {
|
||||||
|
mu.Unlock()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
*sharecount--
|
||||||
|
mu.Unlock()
|
||||||
|
ch.Close()
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
} else {
|
||||||
tw := tar.NewWriter(ch)
|
tw := tar.NewWriter(ch)
|
||||||
err := WriteFiles(tw, fs, ".")
|
tarfn := func(fi os.FileInfo, path string, file *os.File) error {
|
||||||
|
hdr, err := tar.FileInfoHeader(fi, "")
|
||||||
|
hdr.Name = path
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Couldn't build tar header: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tw.WriteHeader(hdr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Couldn't write tar header: %v", err)
|
||||||
|
}
|
||||||
|
_, err = io.Copy(tw, file)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err := WriteFiles(tarfn, fs, ".")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintln(ch, "Error building tar", err)
|
fmt.Fprintln(ch, "Error building tar", err)
|
||||||
}
|
}
|
||||||
|
|
@ -216,6 +215,7 @@ func main() {
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
_, err = io.Copy(ch, f)
|
_, err = io.Copy(ch, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
|
|
@ -229,12 +229,10 @@ func main() {
|
||||||
*sharecount--
|
*sharecount--
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
ch.Close()
|
ch.Close()
|
||||||
}()
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func WriteFiles(tw *tar.Writer, fi os.FileInfo, path string) error {
|
func WriteFiles(writerfn func(os.FileInfo, string, *os.File) error, fi os.FileInfo, path string) error {
|
||||||
p := path + "/" + fi.Name()
|
p := path + "/" + fi.Name()
|
||||||
f, err := os.Open(p)
|
f, err := os.Open(p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -247,25 +245,15 @@ func WriteFiles(tw *tar.Writer, fi os.FileInfo, path string) error {
|
||||||
}
|
}
|
||||||
var cumulerror error
|
var cumulerror error
|
||||||
for _, nfi := range fis {
|
for _, nfi := range fis {
|
||||||
err := WriteFiles(tw, nfi, p)
|
err := WriteFiles(writerfn, nfi, p)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cumulerror = fmt.Errorf("%vCouldn't write files for %s: %v\n", cumulerror, nfi.Name(), err)
|
cumulerror = fmt.Errorf("%vCouldn't write files for %s: %v\n", cumulerror, nfi.Name(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return cumulerror
|
return cumulerror
|
||||||
}
|
}
|
||||||
hdr, err := tar.FileInfoHeader(fi, "")
|
|
||||||
hdr.Name = p
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Couldn't build tar header: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tw.WriteHeader(hdr)
|
return writerfn(fi, p, f)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("Couldn't write tar header: %v", err)
|
|
||||||
}
|
|
||||||
_, err = io.Copy(tw, f)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func PublicKeyFile(file string) (ssh.AuthMethod, error) {
|
func PublicKeyFile(file string) (ssh.AuthMethod, error) {
|
||||||
|
|
@ -297,16 +285,16 @@ func SSHAgent() (ssh.AuthMethod, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func runServer(host, sshport, httpport, keyfile string) {
|
func runServer(host, sshport, httpport, keyfile string) {
|
||||||
filemap := map[string]*FileReq{}
|
filemap := map[string]*sharethis.FileReq{}
|
||||||
syncy := &sync.RWMutex{}
|
syncy := &sync.RWMutex{}
|
||||||
|
|
||||||
mapget := func(key string) (*FileReq, bool) {
|
mapget := func(key string) (*sharethis.FileReq, bool) {
|
||||||
syncy.RLock()
|
syncy.RLock()
|
||||||
defer syncy.RUnlock()
|
defer syncy.RUnlock()
|
||||||
c, ok := filemap[key]
|
c, ok := filemap[key]
|
||||||
return c, ok
|
return c, ok
|
||||||
}
|
}
|
||||||
mapset := func(key string, fr *FileReq) {
|
mapset := func(key string, fr *sharethis.FileReq) {
|
||||||
syncy.Lock()
|
syncy.Lock()
|
||||||
filemap[key] = fr
|
filemap[key] = fr
|
||||||
syncy.Unlock()
|
syncy.Unlock()
|
||||||
|
|
@ -369,7 +357,7 @@ func runServer(host, sshport, httpport, keyfile string) {
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
dec := gob.NewDecoder(channel)
|
dec := gob.NewDecoder(channel)
|
||||||
var filereq FileReq
|
var filereq sharethis.FileReq
|
||||||
|
|
||||||
for {
|
for {
|
||||||
err := dec.Decode(&filereq)
|
err := dec.Decode(&filereq)
|
||||||
|
|
@ -378,7 +366,7 @@ func runServer(host, sshport, httpport, keyfile string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: also take the sharecount in the map.
|
// TODO: also take the sharecount in the map.
|
||||||
filereq.serverconn = serverConn
|
filereq.ServerConn = serverConn
|
||||||
mapset(filereq.Path, &filereq)
|
mapset(filereq.Path, &filereq)
|
||||||
go func() { serverConn.Wait(); mapdel(filereq.Path) }()
|
go func() { serverConn.Wait(); mapdel(filereq.Path) }()
|
||||||
return
|
return
|
||||||
|
|
@ -406,7 +394,7 @@ func runServer(host, sshport, httpport, keyfile string) {
|
||||||
}
|
}
|
||||||
fp = strings.TrimSuffix(fp, ".tar")
|
fp = strings.TrimSuffix(fp, ".tar")
|
||||||
if fr, ok := mapget(fp); ok {
|
if fr, ok := mapget(fp); ok {
|
||||||
channel, req, err := fr.serverconn.OpenChannel(fp+suffix, nil)
|
channel, req, err := fr.ServerConn.OpenChannel(fp+suffix, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintln(os.Stderr, err)
|
fmt.Fprintln(os.Stderr, err)
|
||||||
return
|
return
|
||||||
89
filereq.go
Normal file
89
filereq.go
Normal file
|
|
@ -0,0 +1,89 @@
|
||||||
|
package sharethis // import "github.com/otremblay/sharethis"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FileReq struct {
|
||||||
|
Path string
|
||||||
|
ShareCount uint
|
||||||
|
Username string
|
||||||
|
Hostname string
|
||||||
|
ServerConn *ssh.ServerConn
|
||||||
|
fileurl string
|
||||||
|
httpport string
|
||||||
|
remotehost string
|
||||||
|
localpath string
|
||||||
|
filename string
|
||||||
|
IsDir bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFileReq(path, httpport, remotehost string, shareCount uint) *FileReq {
|
||||||
|
return getFileReq(DefaultPrefixFn, path, httpport, remotehost, shareCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fr *FileReq) RebuildUrls() {
|
||||||
|
fr = getFileReq(DefaultPrefixFn, fr.localpath, fr.httpport, fr.remotehost, fr.ShareCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fr *FileReq) String() string {
|
||||||
|
if fr.IsDir {
|
||||||
|
return fmt.Sprintf("%s\n%s", fr.fileurl+".tar", fr.fileurl+".zip")
|
||||||
|
} else {
|
||||||
|
return fmt.Sprintf(fr.fileurl)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRemotePath(prefixFn func() string, path string) string {
|
||||||
|
fullpath := fmt.Sprintf("%s:%s", prefixFn(), path)
|
||||||
|
hashedpath := fmt.Sprintf("%x", sha256.Sum256([]byte(fullpath)))
|
||||||
|
return fmt.Sprintf("%s/%s", hashedpath, filepath.Base(path))
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFileReq(prefixFn func() string, path string, httpport string, remotehost string, sharecount uint) *FileReq {
|
||||||
|
var dir bool
|
||||||
|
if fs, err := os.Stat(path); err != nil {
|
||||||
|
log.Fatalln("Can't read file")
|
||||||
|
} else {
|
||||||
|
p, err := filepath.Abs(fs.Name())
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalln("Can't read file")
|
||||||
|
}
|
||||||
|
dir = fs.IsDir()
|
||||||
|
path = p
|
||||||
|
}
|
||||||
|
remotepath := getRemotePath(prefixFn, path)
|
||||||
|
var fileurl string
|
||||||
|
if httpport == "443" {
|
||||||
|
fileurl = fmt.Sprintf("https://%s/%s", remotehost, remotepath)
|
||||||
|
} else if httpport == "80" {
|
||||||
|
fileurl = fmt.Sprintf("http://%s/%s", remotehost, remotepath)
|
||||||
|
} else {
|
||||||
|
fileurl = fmt.Sprintf("http://%s:%s/%s", remotehost, httpport, remotepath)
|
||||||
|
}
|
||||||
|
return &FileReq{Path: remotepath, ShareCount: sharecount, fileurl: fileurl, localpath: path, IsDir: dir}
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultPrefixFn = func() string {
|
||||||
|
hostname, err := os.Hostname()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "Could not get hostname with os.Hostname()")
|
||||||
|
hostname = "unknown"
|
||||||
|
}
|
||||||
|
var username string
|
||||||
|
userobj, err := user.Current()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintln(os.Stderr, "Could not get user with user.Current()")
|
||||||
|
username = "unknown"
|
||||||
|
} else {
|
||||||
|
username = userobj.Username
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s@%s", username, hostname)
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue