Skip to content

Commit

Permalink
add certReloader.ReloadNow()
Browse files Browse the repository at this point in the history
  • Loading branch information
dreadl0ck committed Oct 3, 2020
1 parent 65bd4f4 commit ef4e749
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 35 deletions.
61 changes: 35 additions & 26 deletions reloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,32 +54,7 @@ func NewCertReloader(certPath, keyPath string, logFile *os.File, cleanup func())
for sig := range sigChan {
if sig == syscall.SIGHUP {
log.Printf("Received SIGHUP, reloading TLS certificate and key from %q and %q", certPath, keyPath)
if err := reloader.maybeReload(); err != nil {

// there was an error reloading the certificate
// rollback files from backup dir
log.Printf("[INFO] simplecert: Keeping old TLS certificate because the new one could not be loaded: %v", err)

// restore private key
backupPrivKey := filepath.Join(c.CacheDir, "backup-"+backupDate, keyFileName)
err = os.Rename(backupPrivKey, filepath.Join(c.CacheDir, keyFileName))
if err != nil {
log.Fatal("[FATAL] simplecert: failed to move key into backup dir: ", err)
}

// restore certificate
backupCert := filepath.Join(c.CacheDir, "backup-"+backupDate, certFileName)
err = os.Rename(backupCert, filepath.Join(c.CacheDir, certFileName))
if err != nil {
log.Fatal("[FATAL] simplecert: failed to move cert into backup dir: ", err)
}

// remove backup directory
err = os.Remove(filepath.Join(c.CacheDir, "backup-"+backupDate))
if err != nil {
log.Fatal("[FATAL] simplecert: failed to remove backup dir: ", err)
}
}
reloader.reload()
} else {
// cleanup
err := logFile.Close()
Expand Down Expand Up @@ -122,3 +97,37 @@ func (reloader *CertReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*
return reloader.cert, nil
}
}

// ReloadNow will force reloading the cert from disk
func (reloader *CertReloader) ReloadNow() {
reloader.reload()
}

func (reloader *CertReloader) reload() {
if err := reloader.maybeReload(); err != nil {

// there was an error reloading the certificate
// rollback files from backup dir
log.Printf("[INFO] simplecert: Keeping old TLS certificate because the new one could not be loaded: %v", err)

// restore private key
backupPrivKey := filepath.Join(c.CacheDir, "backup-"+backupDate, keyFileName)
err = os.Rename(backupPrivKey, filepath.Join(c.CacheDir, keyFileName))
if err != nil {
log.Fatal("[FATAL] simplecert: failed to move key into backup dir: ", err)
}

// restore certificate
backupCert := filepath.Join(c.CacheDir, "backup-"+backupDate, certFileName)
err = os.Rename(backupCert, filepath.Join(c.CacheDir, certFileName))
if err != nil {
log.Fatal("[FATAL] simplecert: failed to move cert into backup dir: ", err)
}

// remove backup directory
err = os.Remove(filepath.Join(c.CacheDir, "backup-"+backupDate))
if err != nil {
log.Fatal("[FATAL] simplecert: failed to remove backup dir: ", err)
}
}
}
22 changes: 13 additions & 9 deletions simplecert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (
"context"
"fmt"
"log"
"testing"
"net/http"
"os"
"testing"
"time"

"github.com/foomo/tlsconfig"
Expand All @@ -38,8 +38,10 @@ func TestRenewal(t *testing.T) {
os.RemoveAll("simplecert")

var (
numRenews int
ctx, cancel = context.WithCancel(context.Background())
certReloader *CertReloader
err error
numRenews int
ctx, cancel = context.WithCancel(context.Background())

// init strict tlsConfig
tlsconf = tlsconfig.NewServerTLSConfig(tlsconfig.TLSModeServerStrict)
Expand All @@ -65,8 +67,8 @@ func TestRenewal(t *testing.T) {
cfg.SSLEmail = "me@mail.com"
cfg.DirectoryURL = "https://127.0.0.1:14000/dir"

cfg.RenewBefore = int((90 * 24 * time.Hour) - 1 * time.Minute) // renew if older than 1 minute after initial retrieval
cfg.CheckInterval = 20 * time.Second // check every 20 seconds
cfg.RenewBefore = int((90 * 24 * time.Hour) - 1*time.Minute) // renew if older than 1 minute after initial retrieval
cfg.CheckInterval = 20 * time.Second // check every 20 seconds
cfg.CacheDir = "simplecert"

cfg.WillRenewCertificate = func() {
Expand All @@ -85,11 +87,14 @@ func TestRenewal(t *testing.T) {
ctx, cancel = context.WithCancel(context.Background())
srv = makeServer()

// force reload the updated cert from disk
certReloader.ReloadNow()

go serveProd(ctx, srv)
}

// init config
certReloader, err := Init(cfg, func() {
certReloader, err = Init(cfg, func() {
os.Exit(0)
})
if err != nil {
Expand All @@ -113,16 +118,15 @@ func TestRenewal(t *testing.T) {
serveProd(ctx, srv)

fmt.Println("waiting forever")
<- make(chan bool)
<-make(chan bool)
}

func serveProd(ctx context.Context, srv *http.Server) {

// lets go
//cLog.Fatal(srv.ListenAndServeTLS("", ""))
go func() {
if err := srv.ListenAndServeTLS("", ""); err != nil && err != http.ErrServerClosed {
log.Fatalf("listen:%+s\n", err)
log.Fatalf("listen: %+s\n", err)
}
}()

Expand Down

0 comments on commit ef4e749

Please sign in to comment.