mirror of
https://github.com/caddyserver/caddy.git
synced 2026-05-25 16:22:36 -04:00
Compare commits
156 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8ab447e615 | |||
| 0d8384a9b4 | |||
| e14328b71b | |||
| f5aaa471de | |||
| 0b83014ff8 | |||
| 0684cf8611 | |||
| 1570bc5d03 | |||
| 8811853f6d | |||
| b7028b139f | |||
| 620f9687c8 | |||
| 2c43616781 | |||
| d1171af679 | |||
| 598de9e6d9 | |||
| 393bc2992e | |||
| 33f2b16a1b | |||
| f03ad80701 | |||
| a68b01080c | |||
| e0f1a02c37 | |||
| 2358102c07 | |||
| 1533652b78 | |||
| c7562e46a4 | |||
| 8f583dcf36 | |||
| 09188981c4 | |||
| ae5f013a48 | |||
| b7091650f8 | |||
| 3a810c6502 | |||
| 764c9ec956 | |||
| ce0988f48a | |||
| 1c92557c8b | |||
| 8f7a1d6a25 | |||
| 1b085efa47 | |||
| d9e6e7ffa5 | |||
| 05d0b213a9 | |||
| 6f580c6aa3 | |||
| 1d9a094315 | |||
| f6e50890b3 | |||
| 22dfb140d0 | |||
| 15455e5a7e | |||
| f46da403d8 | |||
| 4f5df39bdd | |||
| 1f8d1df4ec | |||
| dd83687447 | |||
| 3ce3f3a96a | |||
| 86060ef9b4 | |||
| d3e3fc533f | |||
| 03b10f9c8e | |||
| f7757da7ed | |||
| 13f9c34d16 | |||
| 13a54dbdda | |||
| 7ed7a95524 | |||
| d47b041923 | |||
| dfbc2e81e3 | |||
| 9edc16e4d6 | |||
| 73273c5bf8 | |||
| 93c5256318 | |||
| 3ccad1814e | |||
| 35269572d7 | |||
| a457b35750 | |||
| 5e5f9b0563 | |||
| 16722e4d99 | |||
| 89c20f9a55 | |||
| d3b731e925 | |||
| 3e0695ee31 | |||
| 9239f3cbcc | |||
| b7a7fd4651 | |||
| 06b067b02c | |||
| dfb5aa6dc6 | |||
| f56696f478 | |||
| fcbb90a9af | |||
| be84b74d01 | |||
| bb5b01c911 | |||
| 3ca6bc4a66 | |||
| 053373a385 | |||
| e263566673 | |||
| 6965075825 | |||
| e54dfa49c3 | |||
| accaa378f0 | |||
| 60a0208e8d | |||
| 2aaaa368bb | |||
| 4829cc6aaf | |||
| 553acf93e2 | |||
| f058419042 | |||
| 13268db536 | |||
| 1f7b5abc80 | |||
| c667f81866 | |||
| b321c00a8f | |||
| 9160789b42 | |||
| df7cdc3fae | |||
| 86fd2f22fb | |||
| 148a6f4430 | |||
| b05006663f | |||
| 5f1f8e4ee6 | |||
| ef48e17e79 | |||
| fe03c1aefa | |||
| 078770a5a6 | |||
| 294f6957f0 | |||
| fe664c00ff | |||
| 518edd3cd4 | |||
| b019501b8b | |||
| 2922d09bef | |||
| 97487e6f0d | |||
| 694d2c9b2e | |||
| a674c0051a | |||
| 98de336a21 | |||
| 9fe2ef417c | |||
| 88edca65d3 | |||
| 64c18a7c6c | |||
| d2fc045219 | |||
| 917a604094 | |||
| b33b24fc9e | |||
| 4d9ee000c8 | |||
| 2966db7b78 | |||
| 38e65e28d4 | |||
| 73b61af58d | |||
| 858e96f21c | |||
| f379bf3421 | |||
| 1896b420d8 | |||
| 1580169e2b | |||
| 95514da91b | |||
| 18ff8748e7 | |||
| 2ed1dd6afc | |||
| 8039a7127f | |||
| a8dfa9f0b7 | |||
| 33aeb1cb5c | |||
| 8bdd13b594 | |||
| 52316952a5 | |||
| 7c868afd32 | |||
| 4df8028bc3 | |||
| f1eaae9b0d | |||
| 385ea53309 | |||
| 2716e272c1 | |||
| ca34a3e1aa | |||
| 3ee6d30659 | |||
| ef40659c70 | |||
| 6e2de19d9f | |||
| 3afb1ae380 | |||
| 37c852c382 | |||
| 3d01f46efa | |||
| 3a6496c268 | |||
| 64c9f20919 | |||
| d10d8c23c4 | |||
| 3cd36fd47d | |||
| aaec7e469c | |||
| 6f78cc49d1 | |||
| 13dfffd203 | |||
| 5552dcbbc7 | |||
| 37b291f82c | |||
| a6521357e5 | |||
| 269a8b5fce | |||
| 5820356cf6 | |||
| 6b3c2212a1 | |||
| 703cf7bf8b | |||
| 3e00e18adc | |||
| 6c17e4d4c8 | |||
| 388ff6bc0a | |||
| 8f0b44b8a4 |
@@ -10,5 +10,10 @@
|
||||
# go fmt will enforce this, but in case a user has not called "go fmt" allow GIT to catch this:
|
||||
*.go text eol=lf core.whitespace whitespace=indent-with-non-tab,trailing-space,tabwidth=4
|
||||
|
||||
*.txt text eol=lf core.whitespace whitespace=tab-in-indent,trailing-space,tabwidth=2
|
||||
*.tpl text eol=lf core.whitespace whitespace=tab-in-indent,trailing-space,tabwidth=2
|
||||
*.htm text eol=lf core.whitespace whitespace=tab-in-indent,trailing-space,tabwidth=2
|
||||
*.html text eol=lf core.whitespace whitespace=tab-in-indent,trailing-space,tabwidth=2
|
||||
*.md text eol=lf core.whitespace whitespace=tab-in-indent,trailing-space,tabwidth=2
|
||||
*.yml text eol=lf core.whitespace whitespace=tab-in-indent,trailing-space,tabwidth=2
|
||||
.git* text eol=auto core.whitespace whitespace=trailing-space
|
||||
|
||||
@@ -103,7 +103,7 @@ While we really do value your requests and implement many of them, not all featu
|
||||
|
||||
### Improving documentation
|
||||
|
||||
Caddy's documentation is available at [https://caddyserver.com/docs](https://caddyserver.com/docs). If you would like to make a fix to the docs, feel free to contribute at the [caddyserver/website](https://github.com/caddyserver/website) repository!
|
||||
Caddy's documentation is available at [https://caddyserver.com/docs](https://caddyserver.com/docs). If you would like to make a fix to the docs, please submit an issue here describing the change to make.
|
||||
|
||||
Note that plugin documentation is not hosted by the Caddy website, other than basic usage examples. They are managed by the individual plugin authors, and you will have to contact them to change their documentation.
|
||||
|
||||
|
||||
+2
-1
@@ -13,9 +13,10 @@ access.log
|
||||
|
||||
/*.conf
|
||||
Caddyfile
|
||||
!caddyfile/
|
||||
|
||||
og_static/
|
||||
|
||||
.vscode/
|
||||
|
||||
*.bat
|
||||
*.bat
|
||||
|
||||
+1
-2
@@ -23,9 +23,8 @@ before_install:
|
||||
install:
|
||||
- if [ "$TRAVIS_PULL_REQUEST" = "false" ]; then bash dist/gitcookie.sh; fi
|
||||
- go get -t ./...
|
||||
- go get github.com/golang/lint/golint
|
||||
- go get golang.org/x/lint/golint
|
||||
- go get github.com/FiloSottile/vendorcheck
|
||||
# Install gometalinter
|
||||
- go get github.com/alecthomas/gometalinter
|
||||
|
||||
script:
|
||||
|
||||
@@ -25,6 +25,12 @@ Caddy is a **production-ready** open-source web server that is fast, easy to use
|
||||
|
||||
Available for Windows, Mac, Linux, BSD, Solaris, and [Android](https://github.com/mholt/caddy/wiki/Running-Caddy-on-Android).
|
||||
|
||||
<p align="center">
|
||||
<b>Thanks to our special sponsor:</b>
|
||||
<br><br>
|
||||
<a href="https://relicabackup.com"><img src="https://caddyserver.com/resources/images/sponsors/relica.png" width="220" alt="Relica - Cross-platform file backup to the cloud, local disks, or other computers"></a>
|
||||
</p>
|
||||
|
||||
## Menu
|
||||
|
||||
- [Features](#features)
|
||||
@@ -51,17 +57,26 @@ Available for Windows, Mac, Linux, BSD, Solaris, and [Android](https://github.co
|
||||
Altogether, Caddy can do things other web servers simply cannot do. Its features and plugins save you time and mistakes, and will cheer you up. Your Caddy instance takes care of the details for you!
|
||||
|
||||
|
||||
<p align="center">
|
||||
<b>Powered by</b>
|
||||
<br>
|
||||
<a href="https://github.com/mholt/certmagic"><img src="https://user-images.githubusercontent.com/1128849/49704830-49d37200-fbd5-11e8-8385-767e0cd033c3.png" alt="CertMagic" width="250"></a>
|
||||
</p>
|
||||
|
||||
|
||||
## Install
|
||||
|
||||
Caddy binaries have no dependencies and are available for every platform. Get Caddy either of these ways:
|
||||
Caddy binaries have no dependencies and are available for every platform. Get Caddy any of these ways:
|
||||
|
||||
- **[Download page](https://caddyserver.com/download)** (RECOMMENDED) allows you to customize your build in the browser
|
||||
- **[Latest release](https://github.com/mholt/caddy/releases/latest)** for pre-built, vanilla binaries
|
||||
- **[AWS Marketplace](https://aws.amazon.com/marketplace/pp/B07J1WNK75?qid=1539015041932&sr=0-1&ref_=srh_res_product_title&cl_spe=C)** makes it easy to deploy directly to your cloud environment. <a href="https://aws.amazon.com/marketplace/pp/B07J1WNK75?qid=1539015041932&sr=0-1&ref_=srh_res_product_title&cl_spe=C" target="_blank">
|
||||
<img src="https://s3.amazonaws.com/cloudformation-examples/cloudformation-launch-stack.png" alt="Get Caddy on the AWS Marketplace" height="25"/></a>
|
||||
|
||||
|
||||
## Build
|
||||
|
||||
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.9 or newer). Follow these instruction for fast building:
|
||||
To build from source you need **[Git](https://git-scm.com/downloads)** and **[Go](https://golang.org/doc/install)** (1.10 or newer). Follow these instruction for fast building:
|
||||
|
||||
- Get the source with `go get github.com/mholt/caddy/caddy` and then run `go get github.com/caddyserver/builds`
|
||||
- Now `cd $GOPATH/src/github.com/mholt/caddy/caddy` and run `go run build.go`
|
||||
@@ -70,6 +85,8 @@ Then make sure the `caddy` binary is in your PATH.
|
||||
|
||||
To build for other platforms, use build.go with the `--goos` and `--goarch` flags.
|
||||
|
||||
When building from source, telemetry is enabled by default. You can disable it by changing `enableTelemetry` in run.go before compiling, or use the `-disabled-metrics` flag at runtime to disable only certain metrics.
|
||||
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -137,13 +154,13 @@ If you have questions or concerns about Caddy' underlying crypto implementations
|
||||
|
||||
## Contributing
|
||||
|
||||
**[Join our forum](https://caddy.community) where you can chat with other Caddy users and developers!** To get familiar with the code base, try [Caddy code search on Sourcegraph](https://sourcegraph.com/github.com/mholt/caddy/-/search)!
|
||||
**[Join our forum](https://caddy.community) where you can chat with other Caddy users and developers!** To get familiar with the code base, try [Caddy code search on Sourcegraph](https://sourcegraph.com/github.com/mholt/caddy/)!
|
||||
|
||||
Please see our [contributing guidelines](https://github.com/mholt/caddy/blob/master/.github/CONTRIBUTING.md) for instructions. If you want to write a plugin, check out the [developer wiki](https://github.com/mholt/caddy/wiki).
|
||||
|
||||
We use GitHub issues and pull requests only for discussing bug reports and the development of specific changes. We welcome all other topics on the [forum](https://caddy.community)!
|
||||
|
||||
If you want to contribute to the documentation, please submit pull requests to [caddyserver/website](https://github.com/caddyserver/website).
|
||||
If you want to contribute to the documentation, please [submit an issue](https://github.com/mholt/caddy/issues/new) describing the change that should be made.
|
||||
|
||||
Thanks for making Caddy -- and the Web -- better!
|
||||
|
||||
|
||||
+4
-5
@@ -10,17 +10,16 @@ clone_folder: c:\gopath\src\github.com\mholt\caddy
|
||||
environment:
|
||||
GOPATH: c:\gopath
|
||||
|
||||
stack: go 1.11
|
||||
|
||||
install:
|
||||
- rmdir c:\go /s /q
|
||||
- appveyor DownloadFile https://storage.googleapis.com/golang/go1.9.windows-amd64.zip
|
||||
- 7z x go1.9.windows-amd64.zip -y -oC:\ > NUL
|
||||
- set PATH=%GOPATH%\bin;%PATH%
|
||||
- set PATH=C:\msys64\mingw64\bin;%PATH%
|
||||
- go version
|
||||
- go env
|
||||
- go get -t ./...
|
||||
- go get github.com/golang/lint/golint
|
||||
- go get golang.org/x/lint/golint
|
||||
- go get github.com/FiloSottile/vendorcheck
|
||||
# Install gometalinter
|
||||
- go get github.com/alecthomas/gometalinter
|
||||
|
||||
build: off
|
||||
|
||||
@@ -41,9 +41,12 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
"github.com/mholt/certmagic"
|
||||
)
|
||||
|
||||
// Configurable application parameters
|
||||
@@ -107,11 +110,12 @@ type Instance struct {
|
||||
servers []ServerListener
|
||||
|
||||
// these callbacks execute when certain events occur
|
||||
onFirstStartup []func() error // starting, not as part of a restart
|
||||
onStartup []func() error // starting, even as part of a restart
|
||||
onRestart []func() error // before restart commences
|
||||
onShutdown []func() error // stopping, even as part of a restart
|
||||
onFinalShutdown []func() error // stopping, not as part of a restart
|
||||
OnFirstStartup []func() error // starting, not as part of a restart
|
||||
OnStartup []func() error // starting, even as part of a restart
|
||||
OnRestart []func() error // before restart commences
|
||||
OnRestartFailed []func() error // if restart failed
|
||||
OnShutdown []func() error // stopping, even as part of a restart
|
||||
OnFinalShutdown []func() error // stopping, not as part of a restart
|
||||
|
||||
// storing values on an instance is preferable to
|
||||
// global state because these will get garbage-
|
||||
@@ -122,6 +126,7 @@ type Instance struct {
|
||||
StorageMu sync.RWMutex
|
||||
}
|
||||
|
||||
// Instances returns the list of instances.
|
||||
func Instances() []*Instance {
|
||||
return instances
|
||||
}
|
||||
@@ -160,13 +165,13 @@ func (i *Instance) Stop() error {
|
||||
// the rest. All the non-nil errors will be returned.
|
||||
func (i *Instance) ShutdownCallbacks() []error {
|
||||
var errs []error
|
||||
for _, shutdownFunc := range i.onShutdown {
|
||||
for _, shutdownFunc := range i.OnShutdown {
|
||||
err := shutdownFunc()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
for _, finalShutdownFunc := range i.onFinalShutdown {
|
||||
for _, finalShutdownFunc := range i.OnFinalShutdown {
|
||||
err := finalShutdownFunc()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
@@ -184,9 +189,26 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
|
||||
i.wg.Add(1)
|
||||
defer i.wg.Done()
|
||||
|
||||
var err error
|
||||
// if something went wrong on restart then run onRestartFailed callbacks
|
||||
defer func() {
|
||||
r := recover()
|
||||
if err != nil || r != nil {
|
||||
for _, fn := range i.OnRestartFailed {
|
||||
err = fn()
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] restart failed: %v", err)
|
||||
}
|
||||
}
|
||||
if r != nil {
|
||||
panic(r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// run restart callbacks
|
||||
for _, fn := range i.onRestart {
|
||||
err := fn()
|
||||
for _, fn := range i.OnRestart {
|
||||
err = fn()
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
@@ -222,22 +244,22 @@ func (i *Instance) Restart(newCaddyfile Input) (*Instance, error) {
|
||||
newInst := &Instance{serverType: newCaddyfile.ServerType(), wg: i.wg, Storage: make(map[interface{}]interface{})}
|
||||
|
||||
// attempt to start new instance
|
||||
err := startWithListenerFds(newCaddyfile, newInst, restartFds)
|
||||
err = startWithListenerFds(newCaddyfile, newInst, restartFds)
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
|
||||
// success! stop the old instance
|
||||
for _, shutdownFunc := range i.onShutdown {
|
||||
err = i.Stop()
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
for _, shutdownFunc := range i.OnShutdown {
|
||||
err = shutdownFunc()
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
}
|
||||
err = i.Stop()
|
||||
if err != nil {
|
||||
return i, err
|
||||
}
|
||||
|
||||
// Execute instantiation events
|
||||
EmitEvent(InstanceStartupEvent, newInst)
|
||||
@@ -254,42 +276,6 @@ func (i *Instance) SaveServer(s Server, ln net.Listener) {
|
||||
i.servers = append(i.servers, ServerListener{server: s, listener: ln})
|
||||
}
|
||||
|
||||
// HasListenerWithAddress returns whether this package is
|
||||
// tracking a server using a listener with the address
|
||||
// addr.
|
||||
func HasListenerWithAddress(addr string) bool {
|
||||
instancesMu.Lock()
|
||||
defer instancesMu.Unlock()
|
||||
for _, inst := range instances {
|
||||
for _, sln := range inst.servers {
|
||||
if listenerAddrEqual(sln.listener, addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// listenerAddrEqual compares a listener's address with
|
||||
// addr. Extra care is taken to match addresses with an
|
||||
// empty hostname portion, as listeners tend to report
|
||||
// [::]:80, for example, when the matching address that
|
||||
// created the listener might be simply :80.
|
||||
func listenerAddrEqual(ln net.Listener, addr string) bool {
|
||||
lnAddr := ln.Addr().String()
|
||||
hostname, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return lnAddr == addr
|
||||
}
|
||||
if lnAddr == net.JoinHostPort("::", port) {
|
||||
return true
|
||||
}
|
||||
if lnAddr == net.JoinHostPort("0.0.0.0", port) {
|
||||
return true
|
||||
}
|
||||
return hostname != "" && lnAddr == addr
|
||||
}
|
||||
|
||||
// TCPServer is a type that can listen and serve connections.
|
||||
// A TCPServer must associate with exactly zero or one net.Listeners.
|
||||
type TCPServer interface {
|
||||
@@ -368,6 +354,11 @@ type GracefulServer interface {
|
||||
// address; you must store the address the
|
||||
// server is to serve on some other way.
|
||||
Address() string
|
||||
|
||||
// WrapListener wraps a listener with the
|
||||
// listener middlewares configured for this
|
||||
// server, if any.
|
||||
WrapListener(net.Listener) net.Listener
|
||||
}
|
||||
|
||||
// Listener is a net.Listener with an underlying file descriptor.
|
||||
@@ -478,6 +469,26 @@ func (i *Instance) Caddyfile() Input {
|
||||
//
|
||||
// This function blocks until all the servers are listening.
|
||||
func Start(cdyfile Input) (*Instance, error) {
|
||||
// set up the clustering plugin, if there is one (and there should
|
||||
// always be one) -- this should be done exactly once, but we can't
|
||||
// do it during init while plugins are still registering, so do it
|
||||
// when starting the first instance)
|
||||
if atomic.CompareAndSwapInt32(&clusterPluginSetup, 0, 1) {
|
||||
clusterPluginName := os.Getenv("CADDY_CLUSTERING")
|
||||
if clusterPluginName == "" {
|
||||
clusterPluginName = "file" // name of default storage plugin as registered in caddytls package
|
||||
}
|
||||
clusterFn, ok := clusterProviders[clusterPluginName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unrecognized cluster plugin (was it included in the Caddy build?): %s", clusterPluginName)
|
||||
}
|
||||
storage, err := clusterFn()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("constructing cluster plugin %s: %v", clusterPluginName, err)
|
||||
}
|
||||
certmagic.DefaultStorage = storage
|
||||
}
|
||||
|
||||
inst := &Instance{serverType: cdyfile.ServerType(), wg: new(sync.WaitGroup), Storage: make(map[interface{}]interface{})}
|
||||
err := startWithListenerFds(cdyfile, inst, nil)
|
||||
if err != nil {
|
||||
@@ -531,14 +542,14 @@ func startWithListenerFds(cdyfile Input, inst *Instance, restartFds map[string]r
|
||||
// run startup callbacks
|
||||
if !IsUpgrade() && restartFds == nil {
|
||||
// first startup means not a restart or upgrade
|
||||
for _, firstStartupFunc := range inst.onFirstStartup {
|
||||
for _, firstStartupFunc := range inst.OnFirstStartup {
|
||||
err = firstStartupFunc()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, startupFunc := range inst.onStartup {
|
||||
for _, startupFunc := range inst.OnStartup {
|
||||
err = startupFunc()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -615,6 +626,8 @@ func ValidateAndExecuteDirectives(cdyfile Input, inst *Instance, justValidate bo
|
||||
return fmt.Errorf("error inspecting server blocks: %v", err)
|
||||
}
|
||||
|
||||
telemetry.Set("num_server_blocks", len(sblocks))
|
||||
|
||||
return executeDirectives(inst, cdyfile.Path(), stype.Directives(), sblocks, justValidate)
|
||||
}
|
||||
|
||||
@@ -722,6 +735,7 @@ func startServers(serverList []Server, inst *Instance, restartFds map[string]res
|
||||
return err
|
||||
}
|
||||
}
|
||||
ln = gs.WrapListener(ln)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -760,6 +774,7 @@ func startServers(serverList []Server, inst *Instance, restartFds map[string]res
|
||||
return err
|
||||
}
|
||||
}
|
||||
ln = gs.WrapListener(ln)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -798,7 +813,7 @@ func startServers(serverList []Server, inst *Instance, restartFds map[string]res
|
||||
continue
|
||||
}
|
||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||
// this error is normal when closing the listener
|
||||
// this error is normal when closing the listener; see https://github.com/golang/go/issues/4373
|
||||
continue
|
||||
}
|
||||
log.Println(err)
|
||||
@@ -854,6 +869,7 @@ func Stop() error {
|
||||
for {
|
||||
instancesMu.Lock()
|
||||
if len(instances) == 0 {
|
||||
instancesMu.Unlock()
|
||||
break
|
||||
}
|
||||
inst := instances[0]
|
||||
@@ -869,7 +885,7 @@ func Stop() error {
|
||||
// explicitly like a common local hostname. addr must only
|
||||
// be a host or a host:port combination.
|
||||
func IsLoopback(addr string) bool {
|
||||
host, _, err := net.SplitHostPort(addr)
|
||||
host, _, err := net.SplitHostPort(strings.ToLower(addr))
|
||||
if err != nil {
|
||||
host = addr // happens if the addr is just a hostname
|
||||
}
|
||||
@@ -998,5 +1014,7 @@ var (
|
||||
DefaultConfigFile = "Caddyfile"
|
||||
)
|
||||
|
||||
var clusterPluginSetup int32 // access atomically
|
||||
|
||||
// CtxKey is a value type for use with context.WithValue.
|
||||
type CtxKey string
|
||||
|
||||
@@ -42,11 +42,13 @@ import (
|
||||
)
|
||||
|
||||
var goos, goarch, goarm string
|
||||
var race bool
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&goos, "goos", "", "GOOS for which to build")
|
||||
flag.StringVar(&goarch, "goarch", "", "GOARCH for which to build")
|
||||
flag.StringVar(&goarm, "goarm", "", "GOARM for which to build")
|
||||
flag.BoolVar(&race, "race", false, "Enable race detector")
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -67,6 +69,9 @@ func main() {
|
||||
args := []string{"build", "-ldflags", ldflags}
|
||||
args = append(args, "-asmflags", fmt.Sprintf("-trimpath=%s", gopath))
|
||||
args = append(args, "-gcflags", fmt.Sprintf("-trimpath=%s", gopath))
|
||||
if race {
|
||||
args = append(args, "-race")
|
||||
}
|
||||
cmd := exec.Command("go", args...)
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdout = os.Stdout
|
||||
@@ -77,6 +82,9 @@ func main() {
|
||||
"GOARCH=" + goarch,
|
||||
"GOARM=" + goarm,
|
||||
} {
|
||||
if race && env == "CGO_ENABLED=0" {
|
||||
continue
|
||||
}
|
||||
cmd.Env = append(cmd.Env, env)
|
||||
}
|
||||
|
||||
|
||||
+256
-22
@@ -15,25 +15,28 @@
|
||||
package caddymain
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
|
||||
"github.com/xenolf/lego/acme"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/klauspost/cpuid"
|
||||
"github.com/mholt/caddy"
|
||||
// plug in the HTTP server type
|
||||
_ "github.com/mholt/caddy/caddyhttp"
|
||||
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
"github.com/mholt/certmagic"
|
||||
lumberjack "gopkg.in/natefinch/lumberjack.v2"
|
||||
|
||||
_ "github.com/mholt/caddy/caddyhttp" // plug in the HTTP server type
|
||||
// This is where other plugins get plugged in (imported)
|
||||
)
|
||||
|
||||
@@ -41,15 +44,17 @@ func init() {
|
||||
caddy.TrapSignals()
|
||||
setVersion()
|
||||
|
||||
flag.BoolVar(&caddytls.Agreed, "agree", false, "Agree to the CA's Subscriber Agreement")
|
||||
flag.StringVar(&caddytls.DefaultCAUrl, "ca", "https://acme-v01.api.letsencrypt.org/directory", "URL to certificate authority's ACME server directory")
|
||||
flag.BoolVar(&caddytls.DisableHTTPChallenge, "disable-http-challenge", caddytls.DisableHTTPChallenge, "Disable the ACME HTTP challenge")
|
||||
flag.BoolVar(&caddytls.DisableTLSSNIChallenge, "disable-tls-sni-challenge", caddytls.DisableTLSSNIChallenge, "Disable the ACME TLS-SNI challenge")
|
||||
flag.BoolVar(&certmagic.Agreed, "agree", false, "Agree to the CA's Subscriber Agreement")
|
||||
flag.StringVar(&certmagic.CA, "ca", certmagic.CA, "URL to certificate authority's ACME server directory")
|
||||
flag.BoolVar(&certmagic.DisableHTTPChallenge, "disable-http-challenge", certmagic.DisableHTTPChallenge, "Disable the ACME HTTP challenge")
|
||||
flag.BoolVar(&certmagic.DisableTLSALPNChallenge, "disable-tls-alpn-challenge", certmagic.DisableTLSALPNChallenge, "Disable the ACME TLS-ALPN challenge")
|
||||
flag.StringVar(&disabledMetrics, "disabled-metrics", "", "Comma-separated list of telemetry metrics to disable")
|
||||
flag.StringVar(&conf, "conf", "", "Caddyfile to load (default \""+caddy.DefaultConfigFile+"\")")
|
||||
flag.StringVar(&cpu, "cpu", "100%", "CPU cap")
|
||||
flag.StringVar(&envFile, "env", "", "Path to file with environment variables to load in KEY=VALUE format")
|
||||
flag.BoolVar(&plugins, "plugins", false, "List installed plugins")
|
||||
flag.StringVar(&caddytls.DefaultEmail, "email", "", "Default ACME CA account email address")
|
||||
flag.DurationVar(&acme.HTTPClient.Timeout, "catimeout", acme.HTTPClient.Timeout, "Default ACME CA HTTP timeout")
|
||||
flag.StringVar(&certmagic.Email, "email", "", "Default ACME CA account email address")
|
||||
flag.DurationVar(&certmagic.HTTPTimeout, "catimeout", certmagic.HTTPTimeout, "Default ACME CA HTTP timeout")
|
||||
flag.StringVar(&logfile, "log", "", "Process log file")
|
||||
flag.StringVar(&caddy.PidFile, "pidfile", "", "Path to write pid file")
|
||||
flag.BoolVar(&caddy.Quiet, "quiet", false, "Quiet mode (no initialization output)")
|
||||
@@ -68,7 +73,7 @@ func Run() {
|
||||
|
||||
caddy.AppName = appName
|
||||
caddy.AppVersion = appVersion
|
||||
acme.UserAgent = appName + "/" + appVersion
|
||||
certmagic.UserAgent = appName + "/" + appVersion
|
||||
|
||||
// Set up process log before anything bad happens
|
||||
switch logfile {
|
||||
@@ -87,6 +92,21 @@ func Run() {
|
||||
})
|
||||
}
|
||||
|
||||
//Load all additional envs as soon as possible
|
||||
if err := LoadEnvFromFile(envFile); err != nil {
|
||||
mustLogFatalf("%v", err)
|
||||
}
|
||||
|
||||
// initialize telemetry client
|
||||
if EnableTelemetry {
|
||||
err := initTelemetry()
|
||||
if err != nil {
|
||||
mustLogFatalf("[ERROR] Initializing telemetry: %v", err)
|
||||
}
|
||||
} else if disabledMetrics != "" {
|
||||
mustLogFatalf("[ERROR] Cannot disable specific metrics because telemetry is disabled")
|
||||
}
|
||||
|
||||
// Check for one-time actions
|
||||
if revoke != "" {
|
||||
err := caddytls.Revoke(revoke)
|
||||
@@ -143,6 +163,26 @@ func Run() {
|
||||
// Execute instantiation events
|
||||
caddy.EmitEvent(caddy.InstanceStartupEvent, instance)
|
||||
|
||||
// Begin telemetry (these are no-ops if telemetry disabled)
|
||||
telemetry.Set("caddy_version", appVersion)
|
||||
telemetry.Set("num_listeners", len(instance.Servers()))
|
||||
telemetry.Set("server_type", serverType)
|
||||
telemetry.Set("os", runtime.GOOS)
|
||||
telemetry.Set("arch", runtime.GOARCH)
|
||||
telemetry.Set("cpu", struct {
|
||||
BrandName string `json:"brand_name,omitempty"`
|
||||
NumLogical int `json:"num_logical,omitempty"`
|
||||
AESNI bool `json:"aes_ni,omitempty"`
|
||||
}{
|
||||
BrandName: cpuid.CPU.BrandName,
|
||||
NumLogical: runtime.NumCPU(),
|
||||
AESNI: cpuid.CPU.AesNi(),
|
||||
})
|
||||
if containerized := detectContainer(); containerized {
|
||||
telemetry.Set("container", containerized)
|
||||
}
|
||||
telemetry.StartEmitting()
|
||||
|
||||
// Twiddle your thumbs
|
||||
instance.Wait()
|
||||
}
|
||||
@@ -266,18 +306,209 @@ func setCPU(cpu string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// detectContainer attempts to determine whether the process is
|
||||
// being run inside a container. References:
|
||||
// https://tuhrig.de/how-to-know-you-are-inside-a-docker-container/
|
||||
// https://stackoverflow.com/a/20012536/1048862
|
||||
// https://gist.github.com/anantkamath/623ce7f5432680749e087cf8cfba9b69
|
||||
func detectContainer() bool {
|
||||
if runtime.GOOS != "linux" {
|
||||
return false
|
||||
}
|
||||
|
||||
file, err := os.Open("/proc/1/cgroup")
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
i := 0
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
i++
|
||||
if i > 1000 {
|
||||
return false
|
||||
}
|
||||
|
||||
line := scanner.Text()
|
||||
parts := strings.SplitN(line, ":", 3)
|
||||
if len(parts) < 3 {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(parts[2], "docker") ||
|
||||
strings.Contains(parts[2], "lxc") ||
|
||||
strings.Contains(parts[2], "moby") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// initTelemetry initializes the telemetry engine.
|
||||
func initTelemetry() error {
|
||||
uuidFilename := filepath.Join(caddy.AssetsPath(), "uuid")
|
||||
if customUUIDFile := os.Getenv("CADDY_UUID_FILE"); customUUIDFile != "" {
|
||||
uuidFilename = customUUIDFile
|
||||
}
|
||||
|
||||
newUUID := func() uuid.UUID {
|
||||
id := uuid.New()
|
||||
err := os.MkdirAll(caddy.AssetsPath(), 0700)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Persisting instance UUID: %v", err)
|
||||
return id
|
||||
}
|
||||
err = ioutil.WriteFile(uuidFilename, []byte(id.String()), 0600) // human-readable as a string
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Persisting instance UUID: %v", err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
var id uuid.UUID
|
||||
|
||||
// load UUID from storage, or create one if we don't have one
|
||||
if uuidFile, err := os.Open(uuidFilename); os.IsNotExist(err) {
|
||||
// no UUID exists yet; create a new one and persist it
|
||||
id = newUUID()
|
||||
} else if err != nil {
|
||||
log.Printf("[ERROR] Loading persistent UUID: %v", err)
|
||||
id = newUUID()
|
||||
} else {
|
||||
defer uuidFile.Close()
|
||||
uuidBytes, err := ioutil.ReadAll(uuidFile)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Reading persistent UUID: %v", err)
|
||||
id = newUUID()
|
||||
} else {
|
||||
id, err = uuid.ParseBytes(uuidBytes)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Parsing UUID: %v", err)
|
||||
id = newUUID()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parse and check the list of disabled metrics
|
||||
var disabledMetricsSlice []string
|
||||
if len(disabledMetrics) > 0 {
|
||||
if len(disabledMetrics) > 1024 {
|
||||
// mitigate disk space exhaustion at the collection endpoint
|
||||
return fmt.Errorf("too many metrics to disable")
|
||||
}
|
||||
disabledMetricsSlice = strings.Split(disabledMetrics, ",")
|
||||
for i, metric := range disabledMetricsSlice {
|
||||
if metric == "instance_id" || metric == "timestamp" || metric == "disabled_metrics" {
|
||||
return fmt.Errorf("instance_id, timestamp, and disabled_metrics cannot be disabled")
|
||||
}
|
||||
if metric == "" {
|
||||
disabledMetricsSlice = append(disabledMetricsSlice[:i], disabledMetricsSlice[i+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// initialize telemetry
|
||||
telemetry.Init(id, disabledMetricsSlice)
|
||||
|
||||
// if any metrics were disabled, report which ones (so we know how representative the data is)
|
||||
if len(disabledMetricsSlice) > 0 {
|
||||
telemetry.Set("disabled_metrics", disabledMetricsSlice)
|
||||
log.Printf("[NOTICE] The following telemetry metrics are disabled: %s", disabledMetrics)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadEnvFromFile loads additional envs if file provided and exists
|
||||
// Envs in file should be in KEY=VALUE format
|
||||
func LoadEnvFromFile(envFile string) error {
|
||||
if envFile == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
file, err := os.Open(envFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
envMap, err := ParseEnvFile(file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for k, v := range envMap {
|
||||
if err := os.Setenv(k, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseEnvFile implements parse logic for environment files
|
||||
func ParseEnvFile(envInput io.Reader) (map[string]string, error) {
|
||||
envMap := make(map[string]string)
|
||||
|
||||
scanner := bufio.NewScanner(envInput)
|
||||
var line string
|
||||
lineNumber := 0
|
||||
|
||||
for scanner.Scan() {
|
||||
line = strings.TrimSpace(scanner.Text())
|
||||
lineNumber++
|
||||
|
||||
// skip lines starting with comment
|
||||
if strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
|
||||
// skip empty line
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
fields := strings.SplitN(line, "=", 2)
|
||||
if len(fields) != 2 {
|
||||
return nil, fmt.Errorf("Can't parse line %d; line should be in KEY=VALUE format", lineNumber)
|
||||
}
|
||||
|
||||
if strings.Contains(fields[0], " ") {
|
||||
return nil, fmt.Errorf("Can't parse line %d; KEY contains whitespace", lineNumber)
|
||||
}
|
||||
|
||||
key := fields[0]
|
||||
val := fields[1]
|
||||
|
||||
if key == "" {
|
||||
return nil, fmt.Errorf("Can't parse line %d; KEY can't be empty string", lineNumber)
|
||||
}
|
||||
envMap[key] = val
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return envMap, nil
|
||||
}
|
||||
|
||||
const appName = "Caddy"
|
||||
|
||||
// Flags that control program flow or startup
|
||||
var (
|
||||
serverType string
|
||||
conf string
|
||||
cpu string
|
||||
logfile string
|
||||
revoke string
|
||||
version bool
|
||||
plugins bool
|
||||
validate bool
|
||||
serverType string
|
||||
conf string
|
||||
cpu string
|
||||
envFile string
|
||||
logfile string
|
||||
revoke string
|
||||
version bool
|
||||
plugins bool
|
||||
validate bool
|
||||
disabledMetrics string
|
||||
)
|
||||
|
||||
// Build information obtained with the help of -ldflags
|
||||
@@ -292,3 +523,6 @@ var (
|
||||
gitShortStat string // git diff-index --shortstat
|
||||
gitFilesModified string // git diff-index --name-only HEAD
|
||||
)
|
||||
|
||||
// EnableTelemetry defines whether telemetry is enabled in Run.
|
||||
var EnableTelemetry = true
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
package caddymain
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -57,3 +59,34 @@ func TestSetCPU(t *testing.T) {
|
||||
runtime.GOMAXPROCS(currentCPU)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseEnvFile(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want map[string]string
|
||||
wantErr bool
|
||||
}{
|
||||
{"parsing KEY=VALUE", "PORT=4096", map[string]string{"PORT": "4096"}, false},
|
||||
{"empty KEY", "=4096", nil, true},
|
||||
{"one value", "test", nil, true},
|
||||
{"comments skipped", "#TEST=1\nPORT=8888", map[string]string{"PORT": "8888"}, false},
|
||||
{"empty line", "\nPORT=7777", map[string]string{"PORT": "7777"}, false},
|
||||
{"comments with space skipped", " #TEST=1", map[string]string{}, false},
|
||||
{"KEY with space", "PORT =8888", nil, true},
|
||||
{"only spaces", " ", map[string]string{}, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := strings.NewReader(tt.input)
|
||||
got, err := ParseEnvFile(reader)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseEnvFile() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("ParseEnvFile() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+69
-38
@@ -15,9 +15,12 @@
|
||||
package caddy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
)
|
||||
|
||||
/*
|
||||
@@ -48,6 +51,70 @@ func TestCaddyStartStop(t *testing.T) {
|
||||
}
|
||||
*/
|
||||
|
||||
// CallbackTestContext implements Context interface
|
||||
type CallbackTestContext struct {
|
||||
// If MakeServersFail is set to true then MakeServers returns an error
|
||||
MakeServersFail bool
|
||||
}
|
||||
|
||||
func (h *CallbackTestContext) InspectServerBlocks(name string, sblock []caddyfile.ServerBlock) ([]caddyfile.ServerBlock, error) {
|
||||
return sblock, nil
|
||||
}
|
||||
func (h *CallbackTestContext) MakeServers() ([]Server, error) {
|
||||
if h.MakeServersFail {
|
||||
return make([]Server, 0), fmt.Errorf("MakeServers failed")
|
||||
}
|
||||
return make([]Server, 0), nil
|
||||
}
|
||||
|
||||
func TestCaddyRestartCallbacks(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
restartFail bool
|
||||
expectedCalls []string
|
||||
}{
|
||||
{false, []string{"OnRestart", "OnShutdown"}},
|
||||
{true, []string{"OnRestart", "OnRestartFailed"}},
|
||||
} {
|
||||
serverName := fmt.Sprintf("%v", i)
|
||||
// RegisterServerType to make successful restart possible
|
||||
RegisterServerType(serverName, ServerType{
|
||||
Directives: func() []string { return []string{} },
|
||||
// If MakeServersFail is true then the restart will fail due to context failure
|
||||
NewContext: func(inst *Instance) Context { return &CallbackTestContext{MakeServersFail: test.restartFail} },
|
||||
})
|
||||
c := NewTestController(serverName, "")
|
||||
c.instance = &Instance{
|
||||
serverType: serverName,
|
||||
wg: new(sync.WaitGroup),
|
||||
}
|
||||
|
||||
// Register callbacks which save the calls order
|
||||
calls := make([]string, 0)
|
||||
c.OnRestart(func() error {
|
||||
calls = append(calls, "OnRestart")
|
||||
return nil
|
||||
})
|
||||
c.OnRestartFailed(func() error {
|
||||
calls = append(calls, "OnRestartFailed")
|
||||
return nil
|
||||
})
|
||||
c.OnShutdown(func() error {
|
||||
calls = append(calls, "OnShutdown")
|
||||
return nil
|
||||
})
|
||||
|
||||
c.instance.Restart(CaddyfileInput{Contents: []byte(""), ServerTypeName: serverName})
|
||||
|
||||
if !reflect.DeepEqual(calls, test.expectedCalls) {
|
||||
t.Errorf("Test %d: Callbacks expected: %v, got: %v", i, test.expectedCalls, calls)
|
||||
}
|
||||
|
||||
c.instance.Stop()
|
||||
c.instance.Wait()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestIsLoopback(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input string
|
||||
@@ -135,39 +202,3 @@ func TestIsInternal(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenerAddrEqual(t *testing.T) {
|
||||
ln1, err := net.Listen("tcp", "[::]:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln1.Close()
|
||||
ln1port := strconv.Itoa(ln1.Addr().(*net.TCPAddr).Port)
|
||||
|
||||
ln2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln2.Close()
|
||||
ln2port := strconv.Itoa(ln2.Addr().(*net.TCPAddr).Port)
|
||||
|
||||
for i, test := range []struct {
|
||||
ln net.Listener
|
||||
addr string
|
||||
expect bool
|
||||
}{
|
||||
{ln1, ":" + ln2port, false},
|
||||
{ln1, "0.0.0.0:" + ln2port, false},
|
||||
{ln1, "0.0.0.0", false},
|
||||
{ln1, ":" + ln1port, true},
|
||||
{ln1, "0.0.0.0:" + ln1port, true},
|
||||
{ln2, ":" + ln2port, false},
|
||||
{ln2, "127.0.0.1:" + ln1port, false},
|
||||
{ln2, "127.0.0.1", false},
|
||||
{ln2, "127.0.0.1:" + ln2port, true},
|
||||
} {
|
||||
if got, want := listenerAddrEqual(test.ln, test.addr), test.expect; got != want {
|
||||
t.Errorf("Test %d (%s == %s): expected %v but was %v", i, test.addr, test.ln.Addr().String(), want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+28
-32
@@ -20,6 +20,8 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
)
|
||||
|
||||
// Parse parses the input just enough to group tokens, in
|
||||
@@ -249,9 +251,10 @@ func (p *parser) doImport() error {
|
||||
if p.definedSnippets != nil && p.definedSnippets[importPattern] != nil {
|
||||
importedTokens = p.definedSnippets[importPattern]
|
||||
} else {
|
||||
// make path relative to Caddyfile rather than current working directory (issue #867)
|
||||
// and then use glob to get list of matching filenames
|
||||
absFile, err := filepath.Abs(p.Dispenser.filename)
|
||||
// make path relative to the file of the _token_ being processed rather
|
||||
// than current working directory (issue #867) and then use glob to get
|
||||
// list of matching filenames
|
||||
absFile, err := filepath.Abs(p.Dispenser.File())
|
||||
if err != nil {
|
||||
return p.Errf("Failed to get absolute path of file: %s: %v", p.Dispenser.filename, err)
|
||||
}
|
||||
@@ -263,14 +266,19 @@ func (p *parser) doImport() error {
|
||||
} else {
|
||||
globPattern = importPattern
|
||||
}
|
||||
if strings.Count(globPattern, "*") > 1 || strings.Count(globPattern, "?") > 1 ||
|
||||
(strings.Contains(globPattern, "[") && strings.Contains(globPattern, "]")) {
|
||||
// See issue #2096 - a pattern with many glob expansions can hang for too long
|
||||
return p.Errf("Glob pattern may only contain one wildcard (*), but has others: %s", globPattern)
|
||||
}
|
||||
matches, err = filepath.Glob(globPattern)
|
||||
|
||||
if err != nil {
|
||||
return p.Errf("Failed to use import pattern %s: %v", importPattern, err)
|
||||
}
|
||||
if len(matches) == 0 {
|
||||
if strings.Contains(globPattern, "*") {
|
||||
log.Printf("[WARNING] No files matching import pattern: %s", importPattern)
|
||||
if strings.ContainsAny(globPattern, "*?[]") {
|
||||
log.Printf("[WARNING] No files matching import glob pattern: %s", importPattern)
|
||||
} else {
|
||||
return p.Errf("File to import not found: %s", importPattern)
|
||||
}
|
||||
@@ -283,30 +291,6 @@ func (p *parser) doImport() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var importLine int
|
||||
for i, token := range newTokens {
|
||||
if token.Text == "import" {
|
||||
importLine = token.Line
|
||||
continue
|
||||
}
|
||||
if token.Line == importLine {
|
||||
var abs string
|
||||
if filepath.IsAbs(token.Text) {
|
||||
abs = token.Text
|
||||
} else if !filepath.IsAbs(importFile) {
|
||||
abs = filepath.Join(filepath.Dir(absFile), token.Text)
|
||||
} else {
|
||||
abs = filepath.Join(filepath.Dir(importFile), token.Text)
|
||||
}
|
||||
newTokens[i] = Token{
|
||||
Text: abs,
|
||||
Line: token.Line,
|
||||
File: token.File,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
importedTokens = append(importedTokens, newTokens...)
|
||||
}
|
||||
}
|
||||
@@ -359,7 +343,7 @@ func (p *parser) doSingleImport(importFile string) ([]Token, error) {
|
||||
// are loaded into the current server block for later use
|
||||
// by directive setup functions.
|
||||
func (p *parser) directive() error {
|
||||
dir := p.Val()
|
||||
dir := replaceEnvVars(p.Val())
|
||||
nesting := 0
|
||||
|
||||
// TODO: More helpful error message ("did you mean..." or "maybe you need to install its server type")
|
||||
@@ -369,6 +353,7 @@ func (p *parser) directive() error {
|
||||
|
||||
// The directive itself is appended as a relevant token
|
||||
p.block.Tokens[dir] = append(p.block.Tokens[dir], p.tokens[p.cursor])
|
||||
telemetry.AppendUnique("directives", dir)
|
||||
|
||||
for p.Next() {
|
||||
if p.Val() == "{" {
|
||||
@@ -380,6 +365,12 @@ func (p *parser) directive() error {
|
||||
nesting--
|
||||
} else if p.Val() == "}" && nesting == 0 {
|
||||
return p.Err("Unexpected '}' because no matching opening brace")
|
||||
} else if p.Val() == "import" && p.isNewLine() {
|
||||
if err := p.doImport(); err != nil {
|
||||
return err
|
||||
}
|
||||
p.cursor-- // cursor is advanced when we continue, so roll back one more
|
||||
continue
|
||||
}
|
||||
p.tokens[p.cursor].Text = replaceEnvVars(p.tokens[p.cursor].Text)
|
||||
p.block.Tokens[dir] = append(p.block.Tokens[dir], p.tokens[p.cursor])
|
||||
@@ -439,8 +430,13 @@ func replaceEnvVars(s string) string {
|
||||
func replaceEnvReferences(s, refStart, refEnd string) string {
|
||||
index := strings.Index(s, refStart)
|
||||
for index != -1 {
|
||||
endIndex := strings.Index(s, refEnd)
|
||||
if endIndex != -1 {
|
||||
endIndex := strings.Index(s[index:], refEnd)
|
||||
if endIndex == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
endIndex += index
|
||||
if endIndex > index+len(refStart) {
|
||||
ref := s[index : endIndex+len(refEnd)]
|
||||
s = strings.Replace(s, ref, os.Getenv(ref[len(refStart):len(ref)-len(refEnd)]), -1)
|
||||
} else {
|
||||
|
||||
+177
-9
@@ -228,6 +228,17 @@ func TestParseOneAndImport(t *testing.T) {
|
||||
{`""`, false, []string{}, map[string]int{}},
|
||||
|
||||
{``, false, []string{}, map[string]int{}},
|
||||
|
||||
// test cases found by fuzzing!
|
||||
{`import }{$"`, true, []string{}, map[string]int{}},
|
||||
{`import /*/*.txt`, true, []string{}, map[string]int{}},
|
||||
{`import /???/?*?o`, true, []string{}, map[string]int{}},
|
||||
{`import /??`, true, []string{}, map[string]int{}},
|
||||
{`import /[a-z]`, true, []string{}, map[string]int{}},
|
||||
{`import {$}`, true, []string{}, map[string]int{}},
|
||||
{`import {%}`, true, []string{}, map[string]int{}},
|
||||
{`import {$$}`, true, []string{}, map[string]int{}},
|
||||
{`import {%%}`, true, []string{}, map[string]int{}},
|
||||
} {
|
||||
result, err := testParseOne(test.input)
|
||||
|
||||
@@ -360,6 +371,68 @@ func TestRecursiveImport(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectiveImport(t *testing.T) {
|
||||
testParseOne := func(input string) (ServerBlock, error) {
|
||||
p := testParser(input)
|
||||
p.Next() // parseOne doesn't call Next() to start, so we must
|
||||
err := p.parseOne()
|
||||
return p.block, err
|
||||
}
|
||||
|
||||
isExpected := func(got ServerBlock) bool {
|
||||
if len(got.Keys) != 1 || got.Keys[0] != "localhost" {
|
||||
t.Errorf("got keys unexpected: expect localhost, got %v", got.Keys)
|
||||
return false
|
||||
}
|
||||
if len(got.Tokens) != 2 {
|
||||
t.Errorf("got wrong number of tokens: expect 2, got %d", len(got.Tokens))
|
||||
return false
|
||||
}
|
||||
if len(got.Tokens["dir1"]) != 1 || len(got.Tokens["proxy"]) != 8 {
|
||||
t.Errorf("got unexpect tokens: %v", got.Tokens)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
directiveFile, err := filepath.Abs("testdata/directive_import_test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(directiveFile, []byte(`prop1 1
|
||||
prop2 2`), 0644)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(directiveFile)
|
||||
|
||||
// import from existing file
|
||||
result, err := testParseOne(`localhost
|
||||
dir1
|
||||
proxy {
|
||||
import testdata/directive_import_test
|
||||
transparent
|
||||
}`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !isExpected(result) {
|
||||
t.Error("directive import failed")
|
||||
}
|
||||
|
||||
// import from nonexisting file
|
||||
_, err = testParseOne(`localhost
|
||||
dir1
|
||||
proxy {
|
||||
import testdata/nonexistent_file
|
||||
transparent
|
||||
}`)
|
||||
if err == nil {
|
||||
t.Fatal("expected error when importing a nonexistent file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAll(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
input string
|
||||
@@ -441,6 +514,7 @@ func TestEnvironmentReplacement(t *testing.T) {
|
||||
os.Setenv("PORT", "8080")
|
||||
os.Setenv("ADDRESS", "servername.com")
|
||||
os.Setenv("FOOBAR", "foobar")
|
||||
os.Setenv("PARTIAL_DIR", "r1")
|
||||
|
||||
// basic test; unix-style env vars
|
||||
p := testParser(`{$ADDRESS}`)
|
||||
@@ -449,6 +523,13 @@ func TestEnvironmentReplacement(t *testing.T) {
|
||||
t.Errorf("Expected key to be '%s' but was '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// basic test; unix-style env vars
|
||||
p = testParser(`di{$PARTIAL_DIR}`)
|
||||
blocks, _ = p.parseAll()
|
||||
if actual, expected := blocks[0].Keys[0], "dir1"; expected != actual {
|
||||
t.Errorf("Expected key to be '%s' but was '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// multiple vars per token
|
||||
p = testParser(`{$ADDRESS}:{$PORT}`)
|
||||
blocks, _ = p.parseAll()
|
||||
@@ -507,6 +588,13 @@ func TestEnvironmentReplacement(t *testing.T) {
|
||||
if actual, expected := blocks[0].Tokens["dir1"][1].Text, "Test foobar test"; expected != actual {
|
||||
t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual)
|
||||
}
|
||||
|
||||
// after end token
|
||||
p = testParser(":1234\nanswer \"{{ .Name }} {$FOOBAR}\"")
|
||||
blocks, _ = p.parseAll()
|
||||
if actual, expected := blocks[0].Tokens["answer"][1].Text, "{{ .Name }} foobar"; expected != actual {
|
||||
t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func testParser(input string) parser {
|
||||
@@ -516,15 +604,15 @@ func testParser(input string) parser {
|
||||
}
|
||||
|
||||
func TestSnippets(t *testing.T) {
|
||||
p := testParser(`(common) {
|
||||
gzip foo
|
||||
errors stderr
|
||||
|
||||
}
|
||||
http://example.com {
|
||||
import common
|
||||
}
|
||||
`)
|
||||
p := testParser(`
|
||||
(common) {
|
||||
gzip foo
|
||||
errors stderr
|
||||
}
|
||||
http://example.com {
|
||||
import common
|
||||
}
|
||||
`)
|
||||
blocks, err := p.parseAll()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -550,3 +638,83 @@ func TestSnippets(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func writeStringToTempFileOrDie(t *testing.T, str string) (pathToFile string) {
|
||||
file, err := ioutil.TempFile("", t.Name())
|
||||
if err != nil {
|
||||
panic(err) // get a stack trace so we know where this was called from.
|
||||
}
|
||||
if _, err := file.WriteString(str); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return file.Name()
|
||||
}
|
||||
|
||||
func TestImportedFilesIgnoreNonDirectiveImportTokens(t *testing.T) {
|
||||
fileName := writeStringToTempFileOrDie(t, `
|
||||
http://example.com {
|
||||
# This isn't an import directive, it's just an arg with value 'import'
|
||||
basicauth / import password
|
||||
}
|
||||
`)
|
||||
// Parse the root file that imports the other one.
|
||||
p := testParser(`import ` + fileName)
|
||||
blocks, err := p.parseAll()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, b := range blocks {
|
||||
t.Log(b.Keys)
|
||||
t.Log(b.Tokens)
|
||||
}
|
||||
auth := blocks[0].Tokens["basicauth"]
|
||||
line := auth[0].Text + " " + auth[1].Text + " " + auth[2].Text + " " + auth[3].Text
|
||||
if line != "basicauth / import password" {
|
||||
// Previously, it would be changed to:
|
||||
// basicauth / import /path/to/test/dir/password
|
||||
// referencing a file that (probably) doesn't exist and changing the
|
||||
// password!
|
||||
t.Errorf("Expected basicauth tokens to be 'basicauth / import password' but got %#q", line)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSnippetAcrossMultipleFiles(t *testing.T) {
|
||||
// Make the derived Caddyfile that expects (common) to be defined.
|
||||
fileName := writeStringToTempFileOrDie(t, `
|
||||
http://example.com {
|
||||
import common
|
||||
}
|
||||
`)
|
||||
|
||||
// Parse the root file that defines (common) and then imports the other one.
|
||||
p := testParser(`
|
||||
(common) {
|
||||
gzip foo
|
||||
}
|
||||
import ` + fileName + `
|
||||
`)
|
||||
|
||||
blocks, err := p.parseAll()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for _, b := range blocks {
|
||||
t.Log(b.Keys)
|
||||
t.Log(b.Tokens)
|
||||
}
|
||||
if len(blocks) != 1 {
|
||||
t.Fatalf("Expect exactly one server block. Got %d.", len(blocks))
|
||||
}
|
||||
if actual, expected := blocks[0].Keys[0], "http://example.com"; expected != actual {
|
||||
t.Errorf("Expected server name to be '%s' but was '%s'", expected, actual)
|
||||
}
|
||||
if len(blocks[0].Tokens) != 1 {
|
||||
t.Fatalf("Server block should have tokens from import")
|
||||
}
|
||||
if actual, expected := blocks[0].Tokens["gzip"][0].Text, "gzip"; expected != actual {
|
||||
t.Errorf("Expected argument to be '%s' but was '%s'", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func setupBind(c *caddy.Controller) error {
|
||||
if !c.Args(&config.ListenHost) {
|
||||
return c.ArgErr()
|
||||
}
|
||||
config.TLS.ListenHost = config.ListenHost // necessary for ACME challenges, see issue #309
|
||||
config.TLS.Manager.ListenHost = config.ListenHost // necessary for ACME challenges, see issue #309
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func TestSetupBind(t *testing.T) {
|
||||
if got, want := cfg.ListenHost, "1.2.3.4"; got != want {
|
||||
t.Errorf("Expected the config's ListenHost to be %s, was %s", want, got)
|
||||
}
|
||||
if got, want := cfg.TLS.ListenHost, "1.2.3.4"; got != want {
|
||||
if got, want := cfg.TLS.Manager.ListenHost, "1.2.3.4"; got != want {
|
||||
t.Errorf("Expected the TLS config's ListenHost to be %s, was %s", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
+19
-12
@@ -125,6 +125,7 @@ const defaultTemplate = `<!DOCTYPE html>
|
||||
body {
|
||||
font-family: sans-serif;
|
||||
text-rendering: optimizespeed;
|
||||
background-color: #ffffff;
|
||||
}
|
||||
|
||||
a {
|
||||
@@ -145,12 +146,12 @@ header,
|
||||
|
||||
th:first-child,
|
||||
td:first-child {
|
||||
padding-left: 5%;
|
||||
width: 5%;
|
||||
}
|
||||
|
||||
th:last-child,
|
||||
td:last-child {
|
||||
padding-right: 5%;
|
||||
width: 5%;
|
||||
}
|
||||
|
||||
header {
|
||||
@@ -241,20 +242,20 @@ td {
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
td:first-child {
|
||||
width: 100%;
|
||||
td:nth-child(2) {
|
||||
width: 80%;
|
||||
}
|
||||
|
||||
td:nth-child(2) {
|
||||
td:nth-child(3) {
|
||||
padding: 0 20px 0 20px;
|
||||
}
|
||||
|
||||
th:last-child,
|
||||
td:last-child {
|
||||
th:nth-child(4),
|
||||
td:nth-child(4) {
|
||||
text-align: right;
|
||||
}
|
||||
|
||||
td:first-child svg {
|
||||
td:nth-child(2) svg {
|
||||
position: absolute;
|
||||
}
|
||||
|
||||
@@ -301,12 +302,12 @@ footer {
|
||||
display: none;
|
||||
}
|
||||
|
||||
td:first-child {
|
||||
td:nth-child(2) {
|
||||
width: auto;
|
||||
}
|
||||
|
||||
th:nth-child(2),
|
||||
td:nth-child(2) {
|
||||
th:nth-child(3),
|
||||
td:nth-child(3) {
|
||||
padding-right: 5%;
|
||||
text-align: right;
|
||||
}
|
||||
@@ -325,7 +326,7 @@ footer {
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<body onload='filter()'>
|
||||
<svg version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" height="0" width="0" style="position: absolute;">
|
||||
<defs>
|
||||
<!-- Folder -->
|
||||
@@ -390,6 +391,7 @@ footer {
|
||||
<table aria-describedby="summary">
|
||||
<thead>
|
||||
<tr>
|
||||
<th></th>
|
||||
<th>
|
||||
{{- if and (eq .Sort "namedirfirst") (ne .Order "desc")}}
|
||||
<a href="?sort=namedirfirst&order=desc{{if ne 0 .ItemsLimitedTo}}&limit={{.ItemsLimitedTo}}{{end}}" class="icon"><svg width="1em" height=".5em" version="1.1" viewBox="0 0 12.922194 6.0358899"><use xlink:href="#up-arrow"></use></svg></a>
|
||||
@@ -425,11 +427,13 @@ footer {
|
||||
<a href="?sort=time&order=asc{{if ne 0 .ItemsLimitedTo}}&limit={{.ItemsLimitedTo}}{{end}}">Modified</a>
|
||||
{{- end}}
|
||||
</th>
|
||||
<th class="hideable"></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{{- if .CanGoUp}}
|
||||
<tr>
|
||||
<td></td>
|
||||
<td>
|
||||
<a href="..">
|
||||
<span class="goup">Go up</span>
|
||||
@@ -437,10 +441,12 @@ footer {
|
||||
</td>
|
||||
<td>—</td>
|
||||
<td class="hideable">—</td>
|
||||
<td class="hideable"></td>
|
||||
</tr>
|
||||
{{- end}}
|
||||
{{- range .Items}}
|
||||
<tr class="file">
|
||||
<td></td>
|
||||
<td>
|
||||
<a href="{{html .URL}}">
|
||||
{{- if .IsDir}}
|
||||
@@ -457,6 +463,7 @@ footer {
|
||||
<td data-order="{{.Size}}">{{.HumanSize}}</td>
|
||||
{{- end}}
|
||||
<td class="hideable"><time datetime="{{.HumanModTime "2006-01-02T15:04:05Z"}}">{{.HumanModTime "01/02/2006 03:04:05 PM -07:00"}}</time></td>
|
||||
<td class="hideable"></td>
|
||||
</tr>
|
||||
{{- end}}
|
||||
</tbody>
|
||||
|
||||
@@ -46,5 +46,4 @@ import (
|
||||
_ "github.com/mholt/caddy/caddyhttp/timeouts"
|
||||
_ "github.com/mholt/caddy/caddyhttp/websocket"
|
||||
_ "github.com/mholt/caddy/onevent"
|
||||
_ "github.com/mholt/caddy/startupshutdown"
|
||||
)
|
||||
|
||||
@@ -25,9 +25,9 @@ import (
|
||||
// ensure that the standard plugins are in fact plugged in
|
||||
// and registered properly; this is a quick/naive way to do it.
|
||||
func TestStandardPlugins(t *testing.T) {
|
||||
numStandardPlugins := 33 // importing caddyhttp plugs in this many plugins
|
||||
numStandardPlugins := 31 // importing caddyhttp plugs in this many plugins
|
||||
s := caddy.DescribePlugins()
|
||||
if got, want := strings.Count(s, "\n"), numStandardPlugins+5; got != want {
|
||||
if got, want := strings.Count(s, "\n"), numStandardPlugins+7; got != want {
|
||||
t.Errorf("Expected all standard plugins to be plugged in, got:\n%s", s)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,7 +142,7 @@ func (h ErrorHandler) recovery(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Trim file path
|
||||
delim := "/caddy/"
|
||||
delim := "/github.com/mholt/caddy/"
|
||||
pkgPathPos := strings.Index(file, delim)
|
||||
if pkgPathPos > -1 && len(file) > pkgPathPos+len(delim) {
|
||||
file = file[pkgPathPos+len(delim):]
|
||||
|
||||
@@ -33,8 +33,11 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
)
|
||||
|
||||
// Handler is a middleware type that can handle requests as a FastCGI client.
|
||||
@@ -239,9 +242,6 @@ func (h Handler) exists(path string) bool {
|
||||
func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]string, error) {
|
||||
var env map[string]string
|
||||
|
||||
// Get absolute path of requested resource
|
||||
absPath := filepath.Join(rule.Root, fpath)
|
||||
|
||||
// Separate remote IP and port; more lenient than net.SplitHostPort
|
||||
var ip, port string
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx > -1 {
|
||||
@@ -263,11 +263,13 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]
|
||||
docURI := fpath[:splitPos+len(rule.SplitPath)]
|
||||
pathInfo := fpath[splitPos+len(rule.SplitPath):]
|
||||
scriptName := fpath
|
||||
scriptFilename := absPath
|
||||
|
||||
// Strip PATH_INFO from SCRIPT_NAME
|
||||
scriptName = strings.TrimSuffix(scriptName, pathInfo)
|
||||
|
||||
// SCRIPT_FILENAME is the absolute path of SCRIPT_NAME
|
||||
scriptFilename := filepath.Join(rule.Root, scriptName)
|
||||
|
||||
// Add vhost path prefix to scriptName. Otherwise, some PHP software will
|
||||
// have difficulty discovering its URL.
|
||||
pathPrefix, _ := r.Context().Value(caddy.CtxKey("path_prefix")).(string)
|
||||
@@ -283,6 +285,11 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]
|
||||
// Retrieve name of remote user that was set by some downstream middleware such as basicauth.
|
||||
remoteUser, _ := r.Context().Value(httpserver.RemoteUserCtxKey).(string)
|
||||
|
||||
requestScheme := "http"
|
||||
if r.TLS != nil {
|
||||
requestScheme = "https"
|
||||
}
|
||||
|
||||
// Some variables are unused but cleared explicitly to prevent
|
||||
// the parent environment from interfering.
|
||||
env = map[string]string{
|
||||
@@ -299,6 +306,7 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]
|
||||
"REMOTE_IDENT": "", // Not used
|
||||
"REMOTE_USER": remoteUser,
|
||||
"REQUEST_METHOD": r.Method,
|
||||
"REQUEST_SCHEME": requestScheme,
|
||||
"SERVER_NAME": h.ServerName,
|
||||
"SERVER_PORT": h.ServerPort,
|
||||
"SERVER_PROTOCOL": r.Proto,
|
||||
@@ -323,6 +331,19 @@ func (h Handler) buildEnv(r *http.Request, rule Rule, fpath string) (map[string]
|
||||
// Some web apps rely on knowing HTTPS or not
|
||||
if r.TLS != nil {
|
||||
env["HTTPS"] = "on"
|
||||
// and pass the protocol details in a manner compatible with apache's mod_ssl
|
||||
// (which is why they have a SSL_ prefix and not TLS_).
|
||||
v, ok := tlsProtocolStringToMap[r.TLS.Version]
|
||||
if ok {
|
||||
env["SSL_PROTOCOL"] = v
|
||||
}
|
||||
// and pass the cipher suite in a manner compatible with apache's mod_ssl
|
||||
for k, v := range caddytls.SupportedCiphersMap {
|
||||
if v == r.TLS.CipherSuite {
|
||||
env["SSL_CIPHER"] = k
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add env variables from config (with support for placeholders in values)
|
||||
@@ -465,3 +486,11 @@ type LogError string
|
||||
func (l LogError) Error() string {
|
||||
return string(l)
|
||||
}
|
||||
|
||||
// Map of supported protocols to Apache ssl_mod format
|
||||
// Note that these are slightly different from SupportedProtocols in caddytls/config.go's
|
||||
var tlsProtocolStringToMap = map[uint16]string{
|
||||
tls.VersionTLS10: "TLSv1",
|
||||
tls.VersionTLS11: "TLSv1.1",
|
||||
tls.VersionTLS12: "TLSv1.2",
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"net/http/fcgi"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -238,6 +239,21 @@ func TestBuildEnv(t *testing.T) {
|
||||
envExpected = newEnv()
|
||||
envExpected["SCRIPT_NAME"] = "/test/fgci_test.php"
|
||||
testBuildEnv(r, rule, fpath, envExpected)
|
||||
|
||||
// 7. Test SCRIPT_NAME,SCRIPT_FILENAME do not include PATH_INFO
|
||||
fpath = "/fgci_test.php/extra/paths"
|
||||
r = newReq()
|
||||
envExpected = newEnv()
|
||||
envExpected["PATH_INFO"] = "/extra/paths"
|
||||
envExpected["SCRIPT_NAME"] = "/fgci_test.php"
|
||||
envExpected["SCRIPT_FILENAME"] = filepath.FromSlash("/fgci_test.php")
|
||||
testBuildEnv(r, rule, fpath, envExpected)
|
||||
|
||||
// 8. Test REQUEST_SCHEME in env
|
||||
r = newReq()
|
||||
envExpected = newEnv()
|
||||
envExpected["REQUEST_SCHEME"] = "http"
|
||||
testBuildEnv(r, rule, fpath, envExpected)
|
||||
}
|
||||
|
||||
func TestReadTimeout(t *testing.T) {
|
||||
|
||||
@@ -27,6 +27,8 @@ import (
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
)
|
||||
|
||||
var defaultTimeout = 60 * time.Second
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("fastcgi", caddy.Plugin{
|
||||
ServerType: "http",
|
||||
@@ -76,8 +78,11 @@ func fastcgiParse(c *caddy.Controller) ([]Rule, error) {
|
||||
}
|
||||
|
||||
rule := Rule{
|
||||
Root: absRoot,
|
||||
Path: args[0],
|
||||
Root: absRoot,
|
||||
Path: args[0],
|
||||
ConnectTimeout: defaultTimeout,
|
||||
ReadTimeout: defaultTimeout,
|
||||
SendTimeout: defaultTimeout,
|
||||
}
|
||||
|
||||
upstreams := []string{args[1]}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/httpserver"
|
||||
@@ -53,6 +54,18 @@ func TestSetup(t *testing.T) {
|
||||
if addr != "127.0.0.1:9000" {
|
||||
t.Errorf("Expected 127.0.0.1:9000 as the Address")
|
||||
}
|
||||
|
||||
if myHandler.Rules[0].ConnectTimeout != 60*time.Second {
|
||||
t.Errorf("Expected default value of 60 seconds")
|
||||
}
|
||||
|
||||
if myHandler.Rules[0].ReadTimeout != 60*time.Second {
|
||||
t.Errorf("Expected default value of 60 seconds")
|
||||
}
|
||||
|
||||
if myHandler.Rules[0].SendTimeout != 60*time.Second {
|
||||
t.Errorf("Expected default value of 60 seconds")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFastcgiParse(t *testing.T) {
|
||||
@@ -64,21 +77,23 @@ func TestFastcgiParse(t *testing.T) {
|
||||
|
||||
{`fastcgi /blog 127.0.0.1:9000 php`,
|
||||
false, []Rule{{
|
||||
Path: "/blog",
|
||||
balancer: &roundRobin{addresses: []string{"127.0.0.1:9000"}},
|
||||
Ext: ".php",
|
||||
SplitPath: ".php",
|
||||
IndexFiles: []string{"index.php"},
|
||||
Path: "/blog",
|
||||
balancer: &roundRobin{addresses: []string{"127.0.0.1:9000"}},
|
||||
Ext: ".php",
|
||||
SplitPath: ".php",
|
||||
IndexFiles: []string{"index.php"},
|
||||
SendTimeout: 60 * time.Second,
|
||||
}}},
|
||||
{`fastcgi / 127.0.0.1:9001 {
|
||||
split .html
|
||||
}`,
|
||||
false, []Rule{{
|
||||
Path: "/",
|
||||
balancer: &roundRobin{addresses: []string{"127.0.0.1:9001"}},
|
||||
Ext: "",
|
||||
SplitPath: ".html",
|
||||
IndexFiles: []string{},
|
||||
Path: "/",
|
||||
balancer: &roundRobin{addresses: []string{"127.0.0.1:9001"}},
|
||||
Ext: "",
|
||||
SplitPath: ".html",
|
||||
IndexFiles: []string{},
|
||||
SendTimeout: 60 * time.Second,
|
||||
}}},
|
||||
{`fastcgi / 127.0.0.1:9001 {
|
||||
split .html
|
||||
@@ -91,6 +106,17 @@ func TestFastcgiParse(t *testing.T) {
|
||||
SplitPath: ".html",
|
||||
IndexFiles: []string{},
|
||||
IgnoredSubPaths: []string{"/admin", "/user"},
|
||||
SendTimeout: 60 * time.Second,
|
||||
}}},
|
||||
{`fastcgi / 127.0.0.1:9001 {
|
||||
send_timeout 30s
|
||||
}`,
|
||||
false, []Rule{{
|
||||
Path: "/",
|
||||
balancer: &roundRobin{addresses: []string{"127.0.0.1:9001"}},
|
||||
Ext: "",
|
||||
IndexFiles: []string{},
|
||||
SendTimeout: 30 * time.Second,
|
||||
}}},
|
||||
}
|
||||
for i, test := range tests {
|
||||
@@ -146,6 +172,11 @@ func TestFastcgiParse(t *testing.T) {
|
||||
t.Errorf("Test %d expected %dth FastCGI IgnoredSubPaths to be %s , but got %s",
|
||||
i, j, test.expectedFastcgiConfig[j].IgnoredSubPaths, actualFastcgiConfig.IgnoredSubPaths)
|
||||
}
|
||||
|
||||
if actualFastcgiConfig.SendTimeout != test.expectedFastcgiConfig[j].SendTimeout {
|
||||
t.Errorf("Test %d expected %dth FastCGI SendTimeout to be %s , but got %s",
|
||||
i, j, test.expectedFastcgiConfig[j].SendTimeout, actualFastcgiConfig.SendTimeout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ type RequestFilter interface {
|
||||
|
||||
// defaultExtensions is the list of default extensions for which to enable gzipping.
|
||||
var defaultExtensions = []string{"", ".txt", ".htm", ".html", ".css", ".php", ".js", ".json",
|
||||
".md", ".mdown", ".xml", ".svg", ".go", ".cgi", ".py", ".pl", ".aspx", ".asp"}
|
||||
".md", ".mdown", ".xml", ".svg", ".go", ".cgi", ".py", ".pl", ".aspx", ".asp", ".m3u", ".m3u8"}
|
||||
|
||||
// DefaultExtFilter creates an ExtFilter with default extensions.
|
||||
func DefaultExtFilter() ExtFilter {
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
"github.com/mholt/certmagic"
|
||||
)
|
||||
|
||||
func activateHTTPS(cctx caddy.Context) error {
|
||||
@@ -37,10 +38,13 @@ func activateHTTPS(cctx caddy.Context) error {
|
||||
|
||||
// place certificates and keys on disk
|
||||
for _, c := range ctx.siteConfigs {
|
||||
if c.TLS.OnDemand {
|
||||
if !c.TLS.Managed {
|
||||
continue
|
||||
}
|
||||
if c.TLS.Manager.OnDemand != nil {
|
||||
continue // obtain these certificates on-demand instead
|
||||
}
|
||||
err := c.TLS.ObtainCert(c.TLS.Hostname, operatorPresent)
|
||||
err := c.TLS.Manager.ObtainCert(c.TLS.Hostname, operatorPresent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -62,9 +66,14 @@ func activateHTTPS(cctx caddy.Context) error {
|
||||
// on the ports we'd need to do ACME before we finish starting; parent process
|
||||
// already running renewal ticker, so renewal won't be missed anyway.)
|
||||
if !caddy.IsUpgrade() {
|
||||
err = caddytls.RenewManagedCertificates(true)
|
||||
if err != nil {
|
||||
return err
|
||||
ctx.instance.StorageMu.RLock()
|
||||
certCache, ok := ctx.instance.Storage[caddytls.CertCacheInstStorageKey].(*certmagic.Cache)
|
||||
ctx.instance.StorageMu.RUnlock()
|
||||
if ok && certCache != nil {
|
||||
err = certCache.RenewManagedCertificates(operatorPresent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,13 +104,14 @@ func markQualifiedForAutoHTTPS(configs []*SiteConfig) {
|
||||
// value will always be nil.
|
||||
func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
|
||||
for _, cfg := range configs {
|
||||
if cfg == nil || cfg.TLS == nil || !cfg.TLS.Managed || cfg.TLS.OnDemand {
|
||||
if cfg == nil || cfg.TLS == nil || !cfg.TLS.Managed ||
|
||||
cfg.TLS.Manager == nil || cfg.TLS.Manager.OnDemand != nil {
|
||||
continue
|
||||
}
|
||||
cfg.TLS.Enabled = true
|
||||
cfg.Addr.Scheme = "https"
|
||||
if loadCertificates && caddytls.HostQualifies(cfg.Addr.Host) {
|
||||
_, err := cfg.TLS.CacheManagedCertificate(cfg.Addr.Host)
|
||||
if loadCertificates && certmagic.HostQualifies(cfg.TLS.Hostname) {
|
||||
_, err := cfg.TLS.Manager.CacheManagedCertificate(cfg.TLS.Hostname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -113,7 +123,7 @@ func enableAutoHTTPS(configs []*SiteConfig, loadCertificates bool) error {
|
||||
// Set default port of 443 if not explicitly set
|
||||
if cfg.Addr.Port == "" &&
|
||||
cfg.TLS.Enabled &&
|
||||
(!cfg.TLS.Manual || cfg.TLS.OnDemand) &&
|
||||
(!cfg.TLS.Manual || cfg.TLS.Manager.OnDemand != nil) &&
|
||||
cfg.Addr.Host != "localhost" {
|
||||
cfg.Addr.Port = HTTPSPort
|
||||
}
|
||||
@@ -207,7 +217,7 @@ func redirPlaintextHost(cfg *SiteConfig) *SiteConfig {
|
||||
Addr: Address{Original: addr, Host: host, Port: port},
|
||||
ListenHost: cfg.ListenHost,
|
||||
middleware: []Middleware{redirMiddleware},
|
||||
TLS: &caddytls.Config{AltHTTPPort: cfg.TLS.AltHTTPPort, AltTLSSNIPort: cfg.TLS.AltTLSSNIPort},
|
||||
TLS: &caddytls.Config{Manager: cfg.TLS.Manager},
|
||||
Timeouts: cfg.Timeouts,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
"github.com/mholt/certmagic"
|
||||
)
|
||||
|
||||
func TestRedirPlaintextHost(t *testing.T) {
|
||||
@@ -175,7 +176,7 @@ func TestMakePlaintextRedirects(t *testing.T) {
|
||||
|
||||
func TestEnableAutoHTTPS(t *testing.T) {
|
||||
configs := []*SiteConfig{
|
||||
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Managed: true}},
|
||||
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Managed: true, Manager: &certmagic.Config{}}},
|
||||
{}, // not managed - no changes!
|
||||
}
|
||||
|
||||
@@ -196,18 +197,18 @@ func TestEnableAutoHTTPS(t *testing.T) {
|
||||
func TestMarkQualifiedForAutoHTTPS(t *testing.T) {
|
||||
// TODO: caddytls.TestQualifiesForManagedTLS and this test share nearly the same config list...
|
||||
configs := []*SiteConfig{
|
||||
{Addr: Address{Host: ""}, TLS: new(caddytls.Config)},
|
||||
{Addr: Address{Host: "localhost"}, TLS: new(caddytls.Config)},
|
||||
{Addr: Address{Host: "123.44.3.21"}, TLS: new(caddytls.Config)},
|
||||
{Addr: Address{Host: "example.com"}, TLS: new(caddytls.Config)},
|
||||
{Addr: Address{Host: ""}, TLS: newManagedConfig()},
|
||||
{Addr: Address{Host: "localhost"}, TLS: newManagedConfig()},
|
||||
{Addr: Address{Host: "123.44.3.21"}, TLS: newManagedConfig()},
|
||||
{Addr: Address{Host: "example.com"}, TLS: newManagedConfig()},
|
||||
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{Manual: true}},
|
||||
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{ACMEEmail: "off"}},
|
||||
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{ACMEEmail: "foo@bar.com"}},
|
||||
{Addr: Address{Host: "example.com", Scheme: "http"}, TLS: new(caddytls.Config)},
|
||||
{Addr: Address{Host: "example.com", Port: "80"}, TLS: new(caddytls.Config)},
|
||||
{Addr: Address{Host: "example.com", Port: "1234"}, TLS: new(caddytls.Config)},
|
||||
{Addr: Address{Host: "example.com", Scheme: "https"}, TLS: new(caddytls.Config)},
|
||||
{Addr: Address{Host: "example.com", Port: "80", Scheme: "https"}, TLS: new(caddytls.Config)},
|
||||
{Addr: Address{Host: "example.com"}, TLS: &caddytls.Config{ACMEEmail: "foo@bar.com", Manager: &certmagic.Config{}}},
|
||||
{Addr: Address{Host: "example.com", Scheme: "http"}, TLS: newManagedConfig()},
|
||||
{Addr: Address{Host: "example.com", Port: "80"}, TLS: newManagedConfig()},
|
||||
{Addr: Address{Host: "example.com", Port: "1234"}, TLS: newManagedConfig()},
|
||||
{Addr: Address{Host: "example.com", Scheme: "https"}, TLS: newManagedConfig()},
|
||||
{Addr: Address{Host: "example.com", Port: "80", Scheme: "https"}, TLS: newManagedConfig()},
|
||||
}
|
||||
expectedManagedCount := 4
|
||||
|
||||
@@ -224,3 +225,7 @@ func TestMarkQualifiedForAutoHTTPS(t *testing.T) {
|
||||
t.Errorf("Expected %d managed configs, but got %d", expectedManagedCount, count)
|
||||
}
|
||||
}
|
||||
|
||||
func newManagedConfig() *caddytls.Config {
|
||||
return &caddytls.Config{Manager: &certmagic.Config{}}
|
||||
}
|
||||
|
||||
@@ -44,6 +44,7 @@ type Logger struct {
|
||||
V4ipMask net.IPMask
|
||||
V6ipMask net.IPMask
|
||||
IPMaskExists bool
|
||||
Exceptions []string
|
||||
}
|
||||
|
||||
// NewTestLogger creates logger suitable for testing purposes
|
||||
@@ -84,6 +85,17 @@ func (l Logger) MaskIP(ip string) string {
|
||||
|
||||
}
|
||||
|
||||
// ShouldLog returns true if the path is not exempted from
|
||||
// being logged (i.e. it is not found in l.Exceptions).
|
||||
func (l Logger) ShouldLog(path string) bool {
|
||||
for _, exc := range l.Exceptions {
|
||||
if Path(path).Matches(exc) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Attach binds logger Start and Close functions to
|
||||
// controller's OnStartup and OnShutdown hooks.
|
||||
func (l *Logger) Attach(controller *caddy.Controller) {
|
||||
|
||||
@@ -24,6 +24,9 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
)
|
||||
|
||||
// tlsHandler is a http.Handler that will inject a value
|
||||
@@ -49,6 +52,9 @@ type tlsHandler struct {
|
||||
// Halderman, et. al. in "The Security Impact of HTTPS Interception" (NDSS '17):
|
||||
// https://jhalderm.com/pub/papers/interception-ndss17.pdf
|
||||
func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO: one request per connection, we should report UA in connection with
|
||||
// handshake (reported in caddytls package) and our MITM assessment
|
||||
|
||||
if h.listener == nil {
|
||||
h.next.ServeHTTP(w, r)
|
||||
return
|
||||
@@ -59,11 +65,16 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.listener.helloInfosMu.RUnlock()
|
||||
|
||||
ua := r.Header.Get("User-Agent")
|
||||
uaHash := telemetry.FastHash([]byte(ua))
|
||||
|
||||
// report this request's UA in connection with this ClientHello
|
||||
go telemetry.AppendUnique("tls_client_hello_ua:"+caddytls.ClientHelloInfo(info).Key(), uaHash)
|
||||
|
||||
var checked, mitm bool
|
||||
if r.Header.Get("X-BlueCoat-Via") != "" || // Blue Coat (masks User-Agent header to generic values)
|
||||
r.Header.Get("X-FCCKV2") != "" || // Fortinet
|
||||
info.advertisesHeartbeatSupport() { // no major browsers have ever implemented Heartbeat
|
||||
// TODO: Move the heartbeat check into each "looksLike" function...
|
||||
checked = true
|
||||
mitm = true
|
||||
} else if strings.Contains(ua, "Edge") || strings.Contains(ua, "MSIE") ||
|
||||
@@ -97,6 +108,13 @@ func (h *tlsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if checked {
|
||||
r = r.WithContext(context.WithValue(r.Context(), MitmCtxKey, mitm))
|
||||
if mitm {
|
||||
go telemetry.AppendUnique("http_mitm", "likely")
|
||||
} else {
|
||||
go telemetry.AppendUnique("http_mitm", "unlikely")
|
||||
}
|
||||
} else {
|
||||
go telemetry.AppendUnique("http_mitm", "unknown")
|
||||
}
|
||||
|
||||
if mitm && h.closeOnMITM {
|
||||
@@ -195,6 +213,11 @@ func (c *clientHelloConn) Read(b []byte) (n int, err error) {
|
||||
c.listener.helloInfos[c.Conn.RemoteAddr().String()] = rawParsed
|
||||
c.listener.helloInfosMu.Unlock()
|
||||
|
||||
// report this ClientHello to telemetry
|
||||
chKey := caddytls.ClientHelloInfo(rawParsed).Key()
|
||||
go telemetry.SetNested("tls_client_hello", chKey, rawParsed)
|
||||
go telemetry.AppendUnique("tls_client_hello_count", chKey)
|
||||
|
||||
c.readHello = true
|
||||
return
|
||||
}
|
||||
@@ -215,6 +238,7 @@ func parseRawClientHello(data []byte) (info rawHelloInfo) {
|
||||
if len(data) < 42 {
|
||||
return
|
||||
}
|
||||
info.Version = uint16(data[4])<<8 | uint16(data[5])
|
||||
sessionIDLen := int(data[38])
|
||||
if sessionIDLen > 32 || len(data) < 39+sessionIDLen {
|
||||
return
|
||||
@@ -231,9 +255,9 @@ func parseRawClientHello(data []byte) (info rawHelloInfo) {
|
||||
}
|
||||
numCipherSuites := cipherSuiteLen / 2
|
||||
// read in the cipher suites
|
||||
info.cipherSuites = make([]uint16, numCipherSuites)
|
||||
info.CipherSuites = make([]uint16, numCipherSuites)
|
||||
for i := 0; i < numCipherSuites; i++ {
|
||||
info.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
|
||||
info.CipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
|
||||
}
|
||||
data = data[2+cipherSuiteLen:]
|
||||
if len(data) < 1 {
|
||||
@@ -244,7 +268,7 @@ func parseRawClientHello(data []byte) (info rawHelloInfo) {
|
||||
if len(data) < 1+compressionMethodsLen {
|
||||
return
|
||||
}
|
||||
info.compressionMethods = data[1 : 1+compressionMethodsLen]
|
||||
info.CompressionMethods = data[1 : 1+compressionMethodsLen]
|
||||
|
||||
data = data[1+compressionMethodsLen:]
|
||||
|
||||
@@ -272,7 +296,7 @@ func parseRawClientHello(data []byte) (info rawHelloInfo) {
|
||||
}
|
||||
|
||||
// record that the client advertised support for this extension
|
||||
info.extensions = append(info.extensions, extension)
|
||||
info.Extensions = append(info.Extensions, extension)
|
||||
|
||||
switch extension {
|
||||
case extensionSupportedCurves:
|
||||
@@ -285,10 +309,10 @@ func parseRawClientHello(data []byte) (info rawHelloInfo) {
|
||||
return
|
||||
}
|
||||
numCurves := l / 2
|
||||
info.curves = make([]tls.CurveID, numCurves)
|
||||
info.Curves = make([]tls.CurveID, numCurves)
|
||||
d := data[2:]
|
||||
for i := 0; i < numCurves; i++ {
|
||||
info.curves[i] = tls.CurveID(d[0])<<8 | tls.CurveID(d[1])
|
||||
info.Curves[i] = tls.CurveID(d[0])<<8 | tls.CurveID(d[1])
|
||||
d = d[2:]
|
||||
}
|
||||
case extensionSupportedPoints:
|
||||
@@ -300,8 +324,8 @@ func parseRawClientHello(data []byte) (info rawHelloInfo) {
|
||||
if length != l+1 {
|
||||
return
|
||||
}
|
||||
info.points = make([]uint8, l)
|
||||
copy(info.points, data[1:])
|
||||
info.Points = make([]uint8, l)
|
||||
copy(info.Points, data[1:])
|
||||
}
|
||||
|
||||
data = data[length:]
|
||||
@@ -352,18 +376,12 @@ func (l *tlsHelloListener) Accept() (net.Conn, error) {
|
||||
// by Durumeric, Halderman, et. al. in
|
||||
// "The Security Impact of HTTPS Interception":
|
||||
// https://jhalderm.com/pub/papers/interception-ndss17.pdf
|
||||
type rawHelloInfo struct {
|
||||
cipherSuites []uint16
|
||||
extensions []uint16
|
||||
compressionMethods []byte
|
||||
curves []tls.CurveID
|
||||
points []uint8
|
||||
}
|
||||
type rawHelloInfo caddytls.ClientHelloInfo
|
||||
|
||||
// advertisesHeartbeatSupport returns true if info indicates
|
||||
// that the client supports the Heartbeat extension.
|
||||
func (info rawHelloInfo) advertisesHeartbeatSupport() bool {
|
||||
for _, ext := range info.extensions {
|
||||
for _, ext := range info.Extensions {
|
||||
if ext == extensionHeartbeat {
|
||||
return true
|
||||
}
|
||||
@@ -386,31 +404,31 @@ func (info rawHelloInfo) looksLikeFirefox() bool {
|
||||
// Note: Firefox 55+ doesn't appear to advertise 0xFF03 (65283, short headers). It used to be between 5 and 13.
|
||||
// Note: Firefox on Fedora (or RedHat) doesn't include ECC suites because of patent liability.
|
||||
requiredExtensionsOrder := []uint16{23, 65281, 10, 11, 35, 16, 5, 13}
|
||||
if !assertPresenceAndOrdering(requiredExtensionsOrder, info.extensions, true) {
|
||||
if !assertPresenceAndOrdering(requiredExtensionsOrder, info.Extensions, true) {
|
||||
return false
|
||||
}
|
||||
|
||||
// We check for both presence of curves and their ordering.
|
||||
requiredCurves := []tls.CurveID{29, 23, 24, 25}
|
||||
if len(info.curves) < len(requiredCurves) {
|
||||
if len(info.Curves) < len(requiredCurves) {
|
||||
return false
|
||||
}
|
||||
for i := range requiredCurves {
|
||||
if info.curves[i] != requiredCurves[i] {
|
||||
if info.Curves[i] != requiredCurves[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(info.curves) > len(requiredCurves) {
|
||||
if len(info.Curves) > len(requiredCurves) {
|
||||
// newer Firefox (55 Nightly?) may have additional curves at end of list
|
||||
allowedCurves := []tls.CurveID{256, 257}
|
||||
for i := range allowedCurves {
|
||||
if info.curves[len(requiredCurves)+i] != allowedCurves[i] {
|
||||
if info.Curves[len(requiredCurves)+i] != allowedCurves[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasGreaseCiphers(info.cipherSuites) {
|
||||
if hasGreaseCiphers(info.CipherSuites) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -437,7 +455,7 @@ func (info rawHelloInfo) looksLikeFirefox() bool {
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA, // 0x35
|
||||
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, // 0xa
|
||||
}
|
||||
return assertPresenceAndOrdering(expectedCipherSuiteOrder, info.cipherSuites, false)
|
||||
return assertPresenceAndOrdering(expectedCipherSuiteOrder, info.CipherSuites, false)
|
||||
}
|
||||
|
||||
// looksLikeChrome returns true if info looks like a handshake
|
||||
@@ -478,20 +496,20 @@ func (info rawHelloInfo) looksLikeChrome() bool {
|
||||
TLS_DHE_RSA_WITH_AES_128_CBC_SHA: {}, // 0x33
|
||||
TLS_DHE_RSA_WITH_AES_256_CBC_SHA: {}, // 0x39
|
||||
}
|
||||
for _, ext := range info.cipherSuites {
|
||||
for _, ext := range info.CipherSuites {
|
||||
if _, ok := chromeCipherExclusions[ext]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Chrome does not include curve 25 (CurveP521) (as of Chrome 56, Feb. 2017).
|
||||
for _, curve := range info.curves {
|
||||
for _, curve := range info.Curves {
|
||||
if curve == 25 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if !hasGreaseCiphers(info.cipherSuites) {
|
||||
if !hasGreaseCiphers(info.CipherSuites) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -509,19 +527,19 @@ func (info rawHelloInfo) looksLikeEdge() bool {
|
||||
// More specifically, the OCSP status request extension appears
|
||||
// *directly* before the other two extensions, which occur in that
|
||||
// order. (I contacted the authors for clarification and verified it.)
|
||||
for i, ext := range info.extensions {
|
||||
for i, ext := range info.Extensions {
|
||||
if ext == extensionOCSPStatusRequest {
|
||||
if len(info.extensions) <= i+2 {
|
||||
if len(info.Extensions) <= i+2 {
|
||||
return false
|
||||
}
|
||||
if info.extensions[i+1] != extensionSupportedCurves ||
|
||||
info.extensions[i+2] != extensionSupportedPoints {
|
||||
if info.Extensions[i+1] != extensionSupportedCurves ||
|
||||
info.Extensions[i+2] != extensionSupportedPoints {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, cs := range info.cipherSuites {
|
||||
for _, cs := range info.CipherSuites {
|
||||
// As of Feb. 2017, Edge does not have 0xff, but Avast adds it
|
||||
if cs == scsvRenegotiation {
|
||||
return false
|
||||
@@ -532,7 +550,7 @@ func (info rawHelloInfo) looksLikeEdge() bool {
|
||||
}
|
||||
}
|
||||
|
||||
if hasGreaseCiphers(info.cipherSuites) {
|
||||
if hasGreaseCiphers(info.CipherSuites) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -558,23 +576,23 @@ func (info rawHelloInfo) looksLikeSafari() bool {
|
||||
|
||||
// We check for the presence and order of the extensions.
|
||||
requiredExtensionsOrder := []uint16{10, 11, 13, 13172, 16, 5, 18, 23}
|
||||
if !assertPresenceAndOrdering(requiredExtensionsOrder, info.extensions, true) {
|
||||
if !assertPresenceAndOrdering(requiredExtensionsOrder, info.Extensions, true) {
|
||||
// Safari on iOS 11 (beta) uses different set/ordering of extensions
|
||||
requiredExtensionsOrderiOS11 := []uint16{65281, 0, 23, 13, 5, 13172, 18, 16, 11, 10}
|
||||
if !assertPresenceAndOrdering(requiredExtensionsOrderiOS11, info.extensions, true) {
|
||||
if !assertPresenceAndOrdering(requiredExtensionsOrderiOS11, info.Extensions, true) {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
// For these versions of Safari, expect TLS_EMPTY_RENEGOTIATION_INFO_SCSV first.
|
||||
if len(info.cipherSuites) < 1 {
|
||||
if len(info.CipherSuites) < 1 {
|
||||
return false
|
||||
}
|
||||
if info.cipherSuites[0] != scsvRenegotiation {
|
||||
if info.CipherSuites[0] != scsvRenegotiation {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if hasGreaseCiphers(info.cipherSuites) {
|
||||
if hasGreaseCiphers(info.CipherSuites) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -599,19 +617,19 @@ func (info rawHelloInfo) looksLikeSafari() bool {
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA, // 0x35
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA, // 0x2f
|
||||
}
|
||||
return assertPresenceAndOrdering(expectedCipherSuiteOrder, info.cipherSuites, true)
|
||||
return assertPresenceAndOrdering(expectedCipherSuiteOrder, info.CipherSuites, true)
|
||||
}
|
||||
|
||||
// looksLikeTor returns true if the info looks like a ClientHello from Tor browser
|
||||
// (based on Firefox).
|
||||
func (info rawHelloInfo) looksLikeTor() bool {
|
||||
requiredExtensionsOrder := []uint16{10, 11, 16, 5, 13}
|
||||
if !assertPresenceAndOrdering(requiredExtensionsOrder, info.extensions, true) {
|
||||
if !assertPresenceAndOrdering(requiredExtensionsOrder, info.Extensions, true) {
|
||||
return false
|
||||
}
|
||||
|
||||
// check for session tickets support; Tor doesn't support them to prevent tracking
|
||||
for _, ext := range info.extensions {
|
||||
for _, ext := range info.Extensions {
|
||||
if ext == 35 {
|
||||
return false
|
||||
}
|
||||
@@ -619,12 +637,12 @@ func (info rawHelloInfo) looksLikeTor() bool {
|
||||
|
||||
// We check for both presence of curves and their ordering, including
|
||||
// an optional curve at the beginning (for Tor based on Firefox 52)
|
||||
infoCurves := info.curves
|
||||
if len(info.curves) == 4 {
|
||||
if info.curves[0] != 29 {
|
||||
infoCurves := info.Curves
|
||||
if len(info.Curves) == 4 {
|
||||
if info.Curves[0] != 29 {
|
||||
return false
|
||||
}
|
||||
infoCurves = info.curves[1:]
|
||||
infoCurves = info.Curves[1:]
|
||||
}
|
||||
requiredCurves := []tls.CurveID{23, 24, 25}
|
||||
if len(infoCurves) < len(requiredCurves) {
|
||||
@@ -636,7 +654,7 @@ func (info rawHelloInfo) looksLikeTor() bool {
|
||||
}
|
||||
}
|
||||
|
||||
if hasGreaseCiphers(info.cipherSuites) {
|
||||
if hasGreaseCiphers(info.CipherSuites) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -663,7 +681,7 @@ func (info rawHelloInfo) looksLikeTor() bool {
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA, // 0x35
|
||||
tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, // 0xa
|
||||
}
|
||||
return assertPresenceAndOrdering(expectedCipherSuiteOrder, info.cipherSuites, false)
|
||||
return assertPresenceAndOrdering(expectedCipherSuiteOrder, info.CipherSuites, false)
|
||||
}
|
||||
|
||||
// assertPresenceAndOrdering will return true if candidateList contains
|
||||
|
||||
@@ -32,44 +32,48 @@ func TestParseClientHello(t *testing.T) {
|
||||
// curl 7.51.0 (x86_64-apple-darwin16.0) libcurl/7.51.0 SecureTransport zlib/1.2.8
|
||||
inputHex: `010000a6030358a28c73a71bdfc1f09dee13fecdc58805dcce42ac44254df548f14645f7dc2c00004400ffc02cc02bc024c023c00ac009c008c030c02fc028c027c014c013c012009f009e006b0067003900330016009d009c003d003c0035002f000a00af00ae008d008c008b01000039000a00080006001700180019000b00020100000d00120010040102010501060104030203050306030005000501000000000012000000170000`,
|
||||
expected: rawHelloInfo{
|
||||
cipherSuites: []uint16{255, 49196, 49195, 49188, 49187, 49162, 49161, 49160, 49200, 49199, 49192, 49191, 49172, 49171, 49170, 159, 158, 107, 103, 57, 51, 22, 157, 156, 61, 60, 53, 47, 10, 175, 174, 141, 140, 139},
|
||||
extensions: []uint16{10, 11, 13, 5, 18, 23},
|
||||
compressionMethods: []byte{0},
|
||||
curves: []tls.CurveID{23, 24, 25},
|
||||
points: []uint8{0},
|
||||
Version: 0x303,
|
||||
CipherSuites: []uint16{255, 49196, 49195, 49188, 49187, 49162, 49161, 49160, 49200, 49199, 49192, 49191, 49172, 49171, 49170, 159, 158, 107, 103, 57, 51, 22, 157, 156, 61, 60, 53, 47, 10, 175, 174, 141, 140, 139},
|
||||
Extensions: []uint16{10, 11, 13, 5, 18, 23},
|
||||
CompressionMethods: []byte{0},
|
||||
Curves: []tls.CurveID{23, 24, 25},
|
||||
Points: []uint8{0},
|
||||
},
|
||||
},
|
||||
{
|
||||
// Chrome 56
|
||||
inputHex: `010000c003031dae75222dae1433a5a283ddcde8ddabaefbf16d84f250eee6fdff48cdfff8a00000201a1ac02bc02fc02cc030cca9cca8cc14cc13c013c014009c009d002f0035000a010000777a7a0000ff010001000000000e000c0000096c6f63616c686f73740017000000230000000d00140012040308040401050308050501080606010201000500050100000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a000a0008aaaa001d001700182a2a000100`,
|
||||
expected: rawHelloInfo{
|
||||
cipherSuites: []uint16{6682, 49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49171, 49172, 156, 157, 47, 53, 10},
|
||||
extensions: []uint16{31354, 65281, 0, 23, 35, 13, 5, 18, 16, 30032, 11, 10, 10794},
|
||||
compressionMethods: []byte{0},
|
||||
curves: []tls.CurveID{43690, 29, 23, 24},
|
||||
points: []uint8{0},
|
||||
Version: 0x303,
|
||||
CipherSuites: []uint16{6682, 49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49171, 49172, 156, 157, 47, 53, 10},
|
||||
Extensions: []uint16{31354, 65281, 0, 23, 35, 13, 5, 18, 16, 30032, 11, 10, 10794},
|
||||
CompressionMethods: []byte{0},
|
||||
Curves: []tls.CurveID{43690, 29, 23, 24},
|
||||
Points: []uint8{0},
|
||||
},
|
||||
},
|
||||
{
|
||||
// Firefox 51
|
||||
inputHex: `010000bd030375f9022fc3a6562467f3540d68013b2d0b961979de6129e944efe0b35531323500001ec02bc02fcca9cca8c02cc030c00ac009c013c01400330039002f0035000a010000760000000e000c0000096c6f63616c686f737400170000ff01000100000a000a0008001d001700180019000b00020100002300000010000e000c02683208687474702f312e31000500050100000000ff030000000d0020001e040305030603020308040805080604010501060102010402050206020202`,
|
||||
expected: rawHelloInfo{
|
||||
cipherSuites: []uint16{49195, 49199, 52393, 52392, 49196, 49200, 49162, 49161, 49171, 49172, 51, 57, 47, 53, 10},
|
||||
extensions: []uint16{0, 23, 65281, 10, 11, 35, 16, 5, 65283, 13},
|
||||
compressionMethods: []byte{0},
|
||||
curves: []tls.CurveID{29, 23, 24, 25},
|
||||
points: []uint8{0},
|
||||
Version: 0x303,
|
||||
CipherSuites: []uint16{49195, 49199, 52393, 52392, 49196, 49200, 49162, 49161, 49171, 49172, 51, 57, 47, 53, 10},
|
||||
Extensions: []uint16{0, 23, 65281, 10, 11, 35, 16, 5, 65283, 13},
|
||||
CompressionMethods: []byte{0},
|
||||
Curves: []tls.CurveID{29, 23, 24, 25},
|
||||
Points: []uint8{0},
|
||||
},
|
||||
},
|
||||
{
|
||||
// openssl s_client (OpenSSL 0.9.8zh 14 Jan 2016)
|
||||
inputHex: `0100012b03035d385236b8ca7b7946fa0336f164e76bf821ed90e8de26d97cc677671b6f36380000acc030c02cc028c024c014c00a00a500a300a1009f006b006a0069006800390038003700360088008700860085c032c02ec02ac026c00fc005009d003d00350084c02fc02bc027c023c013c00900a400a200a0009e00670040003f003e0033003200310030009a0099009800970045004400430042c031c02dc029c025c00ec004009c003c002f009600410007c011c007c00cc00200050004c012c008001600130010000dc00dc003000a00ff0201000055000b000403000102000a001c001a00170019001c001b0018001a0016000e000d000b000c0009000a00230000000d0020001e060106020603050105020503040104020403030103020303020102020203000f000101`,
|
||||
expected: rawHelloInfo{
|
||||
cipherSuites: []uint16{49200, 49196, 49192, 49188, 49172, 49162, 165, 163, 161, 159, 107, 106, 105, 104, 57, 56, 55, 54, 136, 135, 134, 133, 49202, 49198, 49194, 49190, 49167, 49157, 157, 61, 53, 132, 49199, 49195, 49191, 49187, 49171, 49161, 164, 162, 160, 158, 103, 64, 63, 62, 51, 50, 49, 48, 154, 153, 152, 151, 69, 68, 67, 66, 49201, 49197, 49193, 49189, 49166, 49156, 156, 60, 47, 150, 65, 7, 49169, 49159, 49164, 49154, 5, 4, 49170, 49160, 22, 19, 16, 13, 49165, 49155, 10, 255},
|
||||
extensions: []uint16{11, 10, 35, 13, 15},
|
||||
compressionMethods: []byte{1, 0},
|
||||
curves: []tls.CurveID{23, 25, 28, 27, 24, 26, 22, 14, 13, 11, 12, 9, 10},
|
||||
points: []uint8{0, 1, 2},
|
||||
Version: 0x303,
|
||||
CipherSuites: []uint16{49200, 49196, 49192, 49188, 49172, 49162, 165, 163, 161, 159, 107, 106, 105, 104, 57, 56, 55, 54, 136, 135, 134, 133, 49202, 49198, 49194, 49190, 49167, 49157, 157, 61, 53, 132, 49199, 49195, 49191, 49187, 49171, 49161, 164, 162, 160, 158, 103, 64, 63, 62, 51, 50, 49, 48, 154, 153, 152, 151, 69, 68, 67, 66, 49201, 49197, 49193, 49189, 49166, 49156, 156, 60, 47, 150, 65, 7, 49169, 49159, 49164, 49154, 5, 4, 49170, 49160, 22, 19, 16, 13, 49165, 49155, 10, 255},
|
||||
Extensions: []uint16{11, 10, 35, 13, 15},
|
||||
CompressionMethods: []byte{1, 0},
|
||||
Curves: []tls.CurveID{23, 25, 28, 27, 24, 26, 22, 14, 13, 11, 12, 9, 10},
|
||||
Points: []uint8{0, 1, 2},
|
||||
},
|
||||
},
|
||||
} {
|
||||
@@ -338,8 +342,8 @@ func TestHeuristicFunctionsAndHandler(t *testing.T) {
|
||||
(isEdge && (isChrome || isFirefox || isSafari || isTor)) ||
|
||||
(isTor && (isChrome || isFirefox || isSafari || isEdge)) {
|
||||
t.Errorf("[%s] Test %d: Multiple fingerprinting functions matched: "+
|
||||
"Chrome=%v Firefox=%v Safari=%v Edge=%v Tor=%v\n\tparsed hello dec: %+v\n\tparsed hello hex: %#x\n",
|
||||
client, i, isChrome, isFirefox, isSafari, isEdge, isTor, parsed, parsed)
|
||||
"Chrome=%v Firefox=%v Safari=%v Edge=%v Tor=%v\n\tparsed hello dec: %+v\n",
|
||||
client, i, isChrome, isFirefox, isSafari, isEdge, isTor, parsed)
|
||||
}
|
||||
|
||||
// test the handler and detection results
|
||||
@@ -367,8 +371,8 @@ func TestHeuristicFunctionsAndHandler(t *testing.T) {
|
||||
if got != want {
|
||||
t.Errorf("[%s] Test %d: Expected MITM=%v but got %v (type assertion OK (checked)=%v)",
|
||||
client, i, want, got, checked)
|
||||
t.Errorf("[%s] Test %d: Looks like Chrome=%v Firefox=%v Safari=%v Edge=%v Tor=%v\n\tparsed hello dec: %+v\n\tparsed hello hex: %#x\n",
|
||||
client, i, isChrome, isFirefox, isSafari, isEdge, isTor, parsed, parsed)
|
||||
t.Errorf("[%s] Test %d: Looks like Chrome=%v Firefox=%v Safari=%v Edge=%v Tor=%v\n\tparsed hello dec: %+v\n",
|
||||
client, i, isChrome, isFirefox, isSafari, isEdge, isTor, parsed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+149
-22
@@ -15,6 +15,7 @@
|
||||
package httpserver
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -22,6 +23,7 @@ import (
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -29,6 +31,8 @@ import (
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
"github.com/mholt/caddy/caddyhttp/staticfiles"
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
"github.com/mholt/certmagic"
|
||||
)
|
||||
|
||||
const serverType = "http"
|
||||
@@ -65,6 +69,12 @@ func init() {
|
||||
caddy.RegisterParsingCallback(serverType, "root", hideCaddyfile)
|
||||
caddy.RegisterParsingCallback(serverType, "tls", activateHTTPS)
|
||||
caddytls.RegisterConfigGetter(serverType, func(c *caddy.Controller) *caddytls.Config { return GetConfig(c).TLS })
|
||||
|
||||
// disable the caddytls package reporting ClientHellos
|
||||
// to telemetry, since our MITM detector does this but
|
||||
// with more information than the standard lib provides
|
||||
// (as of May 2018)
|
||||
caddytls.ClientHelloTelemetry = false
|
||||
}
|
||||
|
||||
// hideCaddyfile hides the source/origin Caddyfile if it is within the
|
||||
@@ -122,15 +132,17 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
|
||||
// For each address in each server block, make a new config
|
||||
for _, sb := range serverBlocks {
|
||||
for _, key := range sb.Keys {
|
||||
key = strings.ToLower(key)
|
||||
if _, dup := h.keysToSiteConfigs[key]; dup {
|
||||
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
|
||||
}
|
||||
addr, err := standardizeAddress(key)
|
||||
if err != nil {
|
||||
return serverBlocks, err
|
||||
}
|
||||
|
||||
addr = addr.Normalize()
|
||||
key = addr.Key()
|
||||
if _, dup := h.keysToSiteConfigs[key]; dup {
|
||||
return serverBlocks, fmt.Errorf("duplicate site key: %s", key)
|
||||
}
|
||||
|
||||
// Fill in address components from command line so that middleware
|
||||
// have access to the correct information during setup
|
||||
if addr.Host == "" && Host != DefaultHost {
|
||||
@@ -145,7 +157,7 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
|
||||
if addrCopy.Port == "" && Port == DefaultPort {
|
||||
addrCopy.Port = Port
|
||||
}
|
||||
addrStr := strings.ToLower(addrCopy.String())
|
||||
addrStr := addrCopy.String()
|
||||
if otherSiteKey, dup := siteAddrs[addrStr]; dup {
|
||||
err := fmt.Errorf("duplicate site address: %s", addrStr)
|
||||
if (addrCopy.Host == Host && Host != DefaultHost) ||
|
||||
@@ -159,12 +171,20 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
|
||||
|
||||
// If default HTTP or HTTPS ports have been customized,
|
||||
// make sure the ACME challenge ports match
|
||||
var altHTTPPort, altTLSSNIPort string
|
||||
var altHTTPPort, altTLSALPNPort int
|
||||
if HTTPPort != DefaultHTTPPort {
|
||||
altHTTPPort = HTTPPort
|
||||
portInt, err := strconv.Atoi(HTTPPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
altHTTPPort = portInt
|
||||
}
|
||||
if HTTPSPort != DefaultHTTPSPort {
|
||||
altTLSSNIPort = HTTPSPort
|
||||
portInt, err := strconv.Atoi(HTTPSPort)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
altTLSALPNPort = portInt
|
||||
}
|
||||
|
||||
// Make our caddytls.Config, which has a pointer to the
|
||||
@@ -172,8 +192,8 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
|
||||
// to use automatic HTTPS when the time comes
|
||||
caddytlsConfig := caddytls.NewConfig(h.instance)
|
||||
caddytlsConfig.Hostname = addr.Host
|
||||
caddytlsConfig.AltHTTPPort = altHTTPPort
|
||||
caddytlsConfig.AltTLSSNIPort = altTLSSNIPort
|
||||
caddytlsConfig.Manager.AltHTTPPort = altHTTPPort
|
||||
caddytlsConfig.Manager.AltTLSALPNPort = altTLSALPNPort
|
||||
|
||||
// Save the config to our master list, and key it for lookups
|
||||
cfg := &SiteConfig{
|
||||
@@ -205,9 +225,41 @@ func (h *httpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd
|
||||
// MakeServers uses the newly-created siteConfigs to
|
||||
// create and return a list of server instances.
|
||||
func (h *httpContext) MakeServers() ([]caddy.Server, error) {
|
||||
// make sure TLS is disabled for explicitly-HTTP sites
|
||||
// (necessary when HTTP address shares a block containing tls)
|
||||
// make a rough estimate as to whether we're in a "production
|
||||
// environment/system" - start by assuming that most production
|
||||
// servers will set their default CA endpoint to a public,
|
||||
// trusted CA (obviously not a perfect hueristic)
|
||||
var looksLikeProductionCA bool
|
||||
for _, publicCAEndpoint := range caddytls.KnownACMECAs {
|
||||
if strings.Contains(certmagic.CA, publicCAEndpoint) {
|
||||
looksLikeProductionCA = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Iterate each site configuration and make sure that:
|
||||
// 1) TLS is disabled for explicitly-HTTP sites (necessary
|
||||
// when an HTTP address shares a block containing tls)
|
||||
// 2) if QUIC is enabled, TLS ClientAuth is not, because
|
||||
// currently, QUIC does not support ClientAuth (TODO:
|
||||
// revisit this when our QUIC implementation supports it)
|
||||
// 3) if TLS ClientAuth is used, StrictHostMatching is on
|
||||
var atLeastOneSiteLooksLikeProduction bool
|
||||
for _, cfg := range h.siteConfigs {
|
||||
// see if all the addresses (both sites and
|
||||
// listeners) are loopback to help us determine
|
||||
// if this is a "production" instance or not
|
||||
if !atLeastOneSiteLooksLikeProduction {
|
||||
if !caddy.IsLoopback(cfg.Addr.Host) &&
|
||||
!caddy.IsLoopback(cfg.ListenHost) &&
|
||||
(caddytls.QualifiesForManagedTLS(cfg) ||
|
||||
certmagic.HostQualifies(cfg.Addr.Host)) {
|
||||
atLeastOneSiteLooksLikeProduction = true
|
||||
}
|
||||
}
|
||||
|
||||
// make sure TLS is disabled for explicitly-HTTP sites
|
||||
// (necessary when HTTP address shares a block containing tls)
|
||||
if !cfg.TLS.Enabled {
|
||||
continue
|
||||
}
|
||||
@@ -222,12 +274,23 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
|
||||
// is incorrect for this site.
|
||||
cfg.Addr.Scheme = "https"
|
||||
}
|
||||
if cfg.Addr.Port == "" && ((!cfg.TLS.Manual && !cfg.TLS.SelfSigned) || cfg.TLS.OnDemand) {
|
||||
if cfg.Addr.Port == "" && ((!cfg.TLS.Manual && !cfg.TLS.SelfSigned) || cfg.TLS.Manager.OnDemand != nil) {
|
||||
// this is vital, otherwise the function call below that
|
||||
// sets the listener address will use the default port
|
||||
// instead of 443 because it doesn't know about TLS.
|
||||
cfg.Addr.Port = HTTPSPort
|
||||
}
|
||||
if cfg.TLS.ClientAuth != tls.NoClientCert {
|
||||
if QUIC {
|
||||
return nil, fmt.Errorf("cannot enable TLS client authentication with QUIC, because QUIC does not yet support it")
|
||||
}
|
||||
// this must be enabled so that a client cannot connect
|
||||
// using SNI for another site on this listener that
|
||||
// does NOT require ClientAuth, and then send HTTP
|
||||
// requests with the Host header of this site which DOES
|
||||
// require client auth, thus bypassing it...
|
||||
cfg.StrictHostMatching = true
|
||||
}
|
||||
}
|
||||
|
||||
// we must map (group) each config to a bind address
|
||||
@@ -246,22 +309,48 @@ func (h *httpContext) MakeServers() ([]caddy.Server, error) {
|
||||
servers = append(servers, s)
|
||||
}
|
||||
|
||||
// NOTE: This value is only a "good guess". Quite often, development
|
||||
// environments will use internal DNS or a local hosts file to serve
|
||||
// real-looking domains in local development. We can't easily tell
|
||||
// which without doing a DNS lookup, so this guess is definitely naive,
|
||||
// and if we ever want a better guess, we will have to do DNS lookups.
|
||||
deploymentGuess := "dev"
|
||||
if looksLikeProductionCA && atLeastOneSiteLooksLikeProduction {
|
||||
deploymentGuess = "prod"
|
||||
}
|
||||
telemetry.Set("http_deployment_guess", deploymentGuess)
|
||||
telemetry.Set("http_num_sites", len(h.siteConfigs))
|
||||
|
||||
return servers, nil
|
||||
}
|
||||
|
||||
// normalizedKey returns "normalized" key representation:
|
||||
// scheme and host names are lowered, everything else stays the same
|
||||
func normalizedKey(key string) string {
|
||||
addr, err := standardizeAddress(key)
|
||||
if err != nil {
|
||||
return key
|
||||
}
|
||||
return addr.Normalize().Key()
|
||||
}
|
||||
|
||||
// GetConfig gets the SiteConfig that corresponds to c.
|
||||
// If none exist (should only happen in tests), then a
|
||||
// new, empty one will be created.
|
||||
func GetConfig(c *caddy.Controller) *SiteConfig {
|
||||
ctx := c.Context().(*httpContext)
|
||||
key := strings.ToLower(c.Key)
|
||||
key := normalizedKey(c.Key)
|
||||
if cfg, ok := ctx.keysToSiteConfigs[key]; ok {
|
||||
return cfg
|
||||
}
|
||||
// we should only get here during tests because directive
|
||||
// actions typically skip the server blocks where we make
|
||||
// the configs
|
||||
cfg := &SiteConfig{Root: Root, TLS: new(caddytls.Config), IndexPages: staticfiles.DefaultIndexPages}
|
||||
cfg := &SiteConfig{
|
||||
Root: Root,
|
||||
TLS: &caddytls.Config{Manager: certmagic.NewDefault()},
|
||||
IndexPages: staticfiles.DefaultIndexPages,
|
||||
}
|
||||
ctx.saveConfig(key, cfg)
|
||||
return cfg
|
||||
}
|
||||
@@ -358,6 +447,43 @@ func (a Address) VHost() string {
|
||||
return a.Original
|
||||
}
|
||||
|
||||
// Normalize normalizes URL: turn scheme and host names into lower case
|
||||
func (a Address) Normalize() Address {
|
||||
path := a.Path
|
||||
if !CaseSensitivePath {
|
||||
path = strings.ToLower(path)
|
||||
}
|
||||
return Address{
|
||||
Original: a.Original,
|
||||
Scheme: strings.ToLower(a.Scheme),
|
||||
Host: strings.ToLower(a.Host),
|
||||
Port: a.Port,
|
||||
Path: path,
|
||||
}
|
||||
}
|
||||
|
||||
// Key is similar to String, just replaces scheme and host values with modified values.
|
||||
// Unlike String it doesn't add anything default (scheme, port, etc)
|
||||
func (a Address) Key() string {
|
||||
res := ""
|
||||
if a.Scheme != "" {
|
||||
res += a.Scheme + "://"
|
||||
}
|
||||
if a.Host != "" {
|
||||
res += a.Host
|
||||
}
|
||||
if a.Port != "" {
|
||||
if strings.HasPrefix(a.Original[len(res):], ":"+a.Port) {
|
||||
// insert port only if the original has its own explicit port
|
||||
res += ":" + a.Port
|
||||
}
|
||||
}
|
||||
if a.Path != "" {
|
||||
res += a.Path
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// standardizeAddress parses an address string into a structured format with separate
|
||||
// scheme, host, port, and path portions, as well as the original input string.
|
||||
func standardizeAddress(str string) (Address, error) {
|
||||
@@ -485,6 +611,7 @@ var directives = []string{
|
||||
"startup", // TODO: Deprecate this directive
|
||||
"shutdown", // TODO: Deprecate this directive
|
||||
"on",
|
||||
"supervisor", // github.com/lucaslorentz/caddy-supervisor
|
||||
"request_id",
|
||||
"realip", // github.com/captncraig/caddy-realip
|
||||
"git", // github.com/abiosoft/caddy-git
|
||||
@@ -498,22 +625,23 @@ var directives = []string{
|
||||
"cache", // github.com/nicolasazrak/caddy-cache
|
||||
"rewrite",
|
||||
"ext",
|
||||
"minify", // github.com/hacdias/caddy-minify
|
||||
"gzip",
|
||||
"header",
|
||||
"geoip", // github.com/kodnaplakal/caddy-geoip
|
||||
"errors",
|
||||
"authz", // github.com/casbin/caddy-authz
|
||||
"filter", // github.com/echocat/caddy-filter
|
||||
"minify", // github.com/hacdias/caddy-minify
|
||||
"ipfilter", // github.com/pyed/ipfilter
|
||||
"ratelimit", // github.com/xuqingfeng/caddy-rate-limit
|
||||
"search", // github.com/pedronasser/caddy-search
|
||||
"expires", // github.com/epicagency/caddy-expires
|
||||
"forwardproxy", // github.com/caddyserver/forwardproxy
|
||||
"basicauth",
|
||||
"redir",
|
||||
"status",
|
||||
"cors", // github.com/captncraig/cors/caddy
|
||||
"nobots", // github.com/Xumeiquer/nobots
|
||||
"cors", // github.com/captncraig/cors/caddy
|
||||
"s3browser", // github.com/techknowlogick/caddy-s3browser
|
||||
"nobots", // github.com/Xumeiquer/nobots
|
||||
"mime",
|
||||
"login", // github.com/tarent/loginsrv/caddy
|
||||
"reauth", // github.com/freman/caddy-reauth
|
||||
@@ -532,18 +660,17 @@ var directives = []string{
|
||||
"fastcgi",
|
||||
"cgi", // github.com/jung-kurt/caddy-cgi
|
||||
"websocket",
|
||||
"filemanager", // github.com/hacdias/filemanager/caddy/filemanager
|
||||
"filebrowser", // github.com/filebrowser/caddy
|
||||
"webdav", // github.com/hacdias/caddy-webdav
|
||||
"markdown",
|
||||
"browse",
|
||||
"jekyll", // github.com/hacdias/filemanager/caddy/jekyll
|
||||
"hugo", // github.com/hacdias/filemanager/caddy/hugo
|
||||
"mailout", // github.com/SchumacherFM/mailout
|
||||
"awses", // github.com/miquella/caddy-awses
|
||||
"awslambda", // github.com/coopernurse/caddy-awslambda
|
||||
"grpc", // github.com/pieterlouw/caddy-grpc
|
||||
"gopkg", // github.com/zikes/gopkg
|
||||
"restic", // github.com/restic/caddy
|
||||
"wkd", // github.com/emersion/caddy-wkd
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -18,6 +18,10 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"sort"
|
||||
|
||||
"fmt"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
)
|
||||
@@ -147,7 +151,20 @@ func TestInspectServerBlocksWithCustomDefaultPort(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("Didn't expect an error, but got: %v", err)
|
||||
}
|
||||
addr := ctx.keysToSiteConfigs["localhost"].Addr
|
||||
localhostKey := "localhost"
|
||||
item, ok := ctx.keysToSiteConfigs[localhostKey]
|
||||
if !ok {
|
||||
availableKeys := make(sort.StringSlice, len(ctx.keysToSiteConfigs))
|
||||
i := 0
|
||||
for key := range ctx.keysToSiteConfigs {
|
||||
availableKeys[i] = fmt.Sprintf("'%s'", key)
|
||||
i++
|
||||
}
|
||||
availableKeys.Sort()
|
||||
t.Errorf("`%s` not found within registered keys, only these are available: %s", localhostKey, strings.Join(availableKeys, ", "))
|
||||
return
|
||||
}
|
||||
addr := item.Addr
|
||||
if addr.Port != Port {
|
||||
t.Errorf("Expected the port on the address to be set, but got: %#v", addr)
|
||||
}
|
||||
@@ -184,6 +201,64 @@ func TestInspectServerBlocksCaseInsensitiveKey(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyNormalization(t *testing.T) {
|
||||
originalCaseSensitivePath := CaseSensitivePath
|
||||
defer func() {
|
||||
CaseSensitivePath = originalCaseSensitivePath
|
||||
}()
|
||||
CaseSensitivePath = true
|
||||
|
||||
caseSensitiveData := []struct {
|
||||
orig string
|
||||
res string
|
||||
}{
|
||||
{
|
||||
orig: "HTTP://A/ABCDEF",
|
||||
res: "http://a/ABCDEF",
|
||||
},
|
||||
{
|
||||
orig: "A/ABCDEF",
|
||||
res: "a/ABCDEF",
|
||||
},
|
||||
{
|
||||
orig: "A:2015/Port",
|
||||
res: "a:2015/Port",
|
||||
},
|
||||
}
|
||||
for _, item := range caseSensitiveData {
|
||||
v := normalizedKey(item.orig)
|
||||
if v != item.res {
|
||||
t.Errorf("Normalization of `%s` with CaseSensitivePath option set to true must be equal to `%s`, got `%s` instead", item.orig, item.res, v)
|
||||
}
|
||||
}
|
||||
|
||||
CaseSensitivePath = false
|
||||
caseInsensitiveData := []struct {
|
||||
orig string
|
||||
res string
|
||||
}{
|
||||
{
|
||||
orig: "HTTP://A/ABCDEF",
|
||||
res: "http://a/abcdef",
|
||||
},
|
||||
{
|
||||
orig: "A/ABCDEF",
|
||||
res: "a/abcdef",
|
||||
},
|
||||
{
|
||||
orig: "A:2015/Port",
|
||||
res: "a:2015/port",
|
||||
},
|
||||
}
|
||||
for _, item := range caseInsensitiveData {
|
||||
v := normalizedKey(item.orig)
|
||||
if v != item.res {
|
||||
t.Errorf("Normalization of `%s` with CaseSensitivePath option set to false must be equal to `%s`, got `%s` instead", item.orig, item.res, v)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestGetConfig(t *testing.T) {
|
||||
// case insensitivity for key
|
||||
con := caddy.NewTestController("http", "")
|
||||
@@ -201,6 +276,14 @@ func TestGetConfig(t *testing.T) {
|
||||
if cfg == cfg3 {
|
||||
t.Errorf("Expected different configs using when key is different; got %p and %p", cfg, cfg3)
|
||||
}
|
||||
|
||||
con.Key = "foo/foobar"
|
||||
cfg4 := GetConfig(con)
|
||||
con.Key = "foo/Foobar"
|
||||
cfg5 := GetConfig(con)
|
||||
if cfg4 == cfg5 {
|
||||
t.Errorf("Expected different cases in path to differentiate keys in general")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirectivesList(t *testing.T) {
|
||||
|
||||
@@ -16,6 +16,10 @@ package httpserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@@ -29,6 +33,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
)
|
||||
|
||||
// requestReplacer is a strings.Replacer which is used to
|
||||
@@ -140,6 +145,14 @@ func canLogRequest(r *http.Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// unescapeBraces finds escaped braces in s and returns
|
||||
// a string with those braces unescaped.
|
||||
func unescapeBraces(s string) string {
|
||||
s = strings.Replace(s, "\\{", "{", -1)
|
||||
s = strings.Replace(s, "\\}", "}", -1)
|
||||
return s
|
||||
}
|
||||
|
||||
// Replace performs a replacement of values on s and returns
|
||||
// the string with the replaced values.
|
||||
func (r *replacer) Replace(s string) string {
|
||||
@@ -149,32 +162,59 @@ func (r *replacer) Replace(s string) string {
|
||||
}
|
||||
|
||||
result := ""
|
||||
Placeholders: // process each placeholder in sequence
|
||||
for {
|
||||
idxStart := strings.Index(s, "{")
|
||||
if idxStart == -1 {
|
||||
// no placeholder anymore
|
||||
break
|
||||
}
|
||||
idxEnd := strings.Index(s[idxStart:], "}")
|
||||
if idxEnd == -1 {
|
||||
// unpaired placeholder
|
||||
break
|
||||
}
|
||||
idxEnd += idxStart
|
||||
var idxStart, idxEnd int
|
||||
|
||||
// get a replacement
|
||||
placeholder := s[idxStart : idxEnd+1]
|
||||
idxOffset := 0
|
||||
for { // find first unescaped opening brace
|
||||
searchSpace := s[idxOffset:]
|
||||
idxStart = strings.Index(searchSpace, "{")
|
||||
if idxStart == -1 {
|
||||
// no more placeholders
|
||||
break Placeholders
|
||||
}
|
||||
if idxStart == 0 || searchSpace[idxStart-1] != '\\' {
|
||||
// preceding character is not an escape
|
||||
idxStart += idxOffset
|
||||
break
|
||||
}
|
||||
// the brace we found was escaped
|
||||
// search the rest of the string next
|
||||
idxOffset += idxStart + 1
|
||||
}
|
||||
|
||||
idxOffset = 0
|
||||
for { // find first unescaped closing brace
|
||||
searchSpace := s[idxStart+idxOffset:]
|
||||
idxEnd = strings.Index(searchSpace, "}")
|
||||
if idxEnd == -1 {
|
||||
// unpaired placeholder
|
||||
break Placeholders
|
||||
}
|
||||
if idxEnd == 0 || searchSpace[idxEnd-1] != '\\' {
|
||||
// preceding character is not an escape
|
||||
idxEnd += idxOffset + idxStart
|
||||
break
|
||||
}
|
||||
// the brace we found was escaped
|
||||
// search the rest of the string next
|
||||
idxOffset += idxEnd + 1
|
||||
}
|
||||
|
||||
// get a replacement for the unescaped placeholder
|
||||
placeholder := unescapeBraces(s[idxStart : idxEnd+1])
|
||||
replacement := r.getSubstitution(placeholder)
|
||||
|
||||
// append prefix + replacement
|
||||
result += s[:idxStart] + replacement
|
||||
// append unescaped prefix + replacement
|
||||
result += strings.TrimPrefix(unescapeBraces(s[:idxStart]), "\\") + replacement
|
||||
|
||||
// strip out scanned parts
|
||||
s = s[idxEnd+1:]
|
||||
}
|
||||
|
||||
// append unscanned parts
|
||||
return result + s
|
||||
return result + unescapeBraces(s)
|
||||
}
|
||||
|
||||
func roundDuration(d time.Duration) time.Duration {
|
||||
@@ -207,6 +247,15 @@ func round(d, r time.Duration) time.Duration {
|
||||
return d
|
||||
}
|
||||
|
||||
// getPeerCert returns peer certificate
|
||||
func (r *replacer) getPeerCert() *x509.Certificate {
|
||||
if r.request.TLS != nil && len(r.request.TLS.PeerCertificates) > 0 {
|
||||
return r.request.TLS.PeerCertificates[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getSubstitution retrieves value from corresponding key
|
||||
func (r *replacer) getSubstitution(key string) string {
|
||||
// search custom replacements first
|
||||
@@ -319,10 +368,14 @@ func (r *replacer) getSubstitution(key string) string {
|
||||
return url.QueryEscape(r.request.URL.RequestURI())
|
||||
case "{when}":
|
||||
return now().Format(timeFormat)
|
||||
case "{when_iso_local}":
|
||||
return now().Format(timeFormatISO)
|
||||
case "{when_iso}":
|
||||
return now().UTC().Format(timeFormatISOUTC)
|
||||
case "{when_unix}":
|
||||
return strconv.FormatInt(now().Unix(), 10)
|
||||
case "{when_unix_ms}":
|
||||
return strconv.FormatInt(nanoToMilliseconds(now().UnixNano()), 10)
|
||||
case "{file}":
|
||||
_, file := path.Split(r.request.URL.Path)
|
||||
return file
|
||||
@@ -375,14 +428,110 @@ func (r *replacer) getSubstitution(key string) string {
|
||||
}
|
||||
elapsedDuration := time.Since(r.responseRecorder.start)
|
||||
return strconv.FormatInt(convertToMilliseconds(elapsedDuration), 10)
|
||||
case "{tls_protocol}":
|
||||
if r.request.TLS != nil {
|
||||
if name, err := caddytls.GetSupportedProtocolName(r.request.TLS.Version); err == nil {
|
||||
return name
|
||||
} else {
|
||||
return "tls" // this should never happen, but guard in case
|
||||
}
|
||||
}
|
||||
return r.emptyValue // because not using a secure channel
|
||||
case "{tls_cipher}":
|
||||
if r.request.TLS != nil {
|
||||
if name, err := caddytls.GetSupportedCipherName(r.request.TLS.CipherSuite); err == nil {
|
||||
return name
|
||||
} else {
|
||||
return "UNKNOWN" // this should never happen, but guard in case
|
||||
}
|
||||
}
|
||||
return r.emptyValue
|
||||
case "{tls_client_escaped_cert}":
|
||||
cert := r.getPeerCert()
|
||||
if cert != nil {
|
||||
pemBlock := pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}
|
||||
return url.QueryEscape(string(pem.EncodeToMemory(&pemBlock)))
|
||||
}
|
||||
return r.emptyValue
|
||||
case "{tls_client_fingerprint}":
|
||||
cert := r.getPeerCert()
|
||||
if cert != nil {
|
||||
return fmt.Sprintf("%x", sha256.Sum256(cert.Raw))
|
||||
}
|
||||
return r.emptyValue
|
||||
case "{tls_client_i_dn}":
|
||||
cert := r.getPeerCert()
|
||||
if cert != nil {
|
||||
return cert.Issuer.String()
|
||||
}
|
||||
return r.emptyValue
|
||||
case "{tls_client_raw_cert}":
|
||||
cert := r.getPeerCert()
|
||||
if cert != nil {
|
||||
return string(cert.Raw)
|
||||
}
|
||||
return r.emptyValue
|
||||
case "{tls_client_s_dn}":
|
||||
cert := r.getPeerCert()
|
||||
if cert != nil {
|
||||
return cert.Subject.String()
|
||||
}
|
||||
return r.emptyValue
|
||||
case "{tls_client_serial}":
|
||||
cert := r.getPeerCert()
|
||||
if cert != nil {
|
||||
return fmt.Sprintf("%x", cert.SerialNumber)
|
||||
}
|
||||
return r.emptyValue
|
||||
case "{tls_client_v_end}":
|
||||
cert := r.getPeerCert()
|
||||
if cert != nil {
|
||||
return cert.NotAfter.In(time.UTC).Format("Jan 02 15:04:05 2006 MST")
|
||||
}
|
||||
return r.emptyValue
|
||||
case "{tls_client_v_remain}":
|
||||
cert := r.getPeerCert()
|
||||
if cert != nil {
|
||||
now := time.Now().In(time.UTC)
|
||||
days := int64(cert.NotAfter.Sub(now).Seconds() / 86400)
|
||||
return strconv.FormatInt(days, 10)
|
||||
}
|
||||
return r.emptyValue
|
||||
case "{tls_client_v_start}":
|
||||
cert := r.getPeerCert()
|
||||
if cert != nil {
|
||||
return cert.NotBefore.Format("Jan 02 15:04:05 2006 MST")
|
||||
}
|
||||
return r.emptyValue
|
||||
default:
|
||||
// {labelN}
|
||||
if strings.HasPrefix(key, "{label") {
|
||||
nStr := key[6 : len(key)-1] // get the integer N in "{labelN}"
|
||||
n, err := strconv.Atoi(nStr)
|
||||
if err != nil || n < 1 {
|
||||
return r.emptyValue
|
||||
}
|
||||
labels := strings.Split(r.request.Host, ".")
|
||||
if n > len(labels) {
|
||||
return r.emptyValue
|
||||
}
|
||||
return labels[n-1]
|
||||
}
|
||||
}
|
||||
|
||||
return r.emptyValue
|
||||
}
|
||||
|
||||
//convertToMilliseconds returns the number of milliseconds in the given duration
|
||||
func nanoToMilliseconds(d int64) int64 {
|
||||
return d / 1e6
|
||||
}
|
||||
|
||||
// convertToMilliseconds returns the number of milliseconds in the given duration
|
||||
func convertToMilliseconds(d time.Duration) int64 {
|
||||
return d.Nanoseconds() / 1e6
|
||||
return nanoToMilliseconds(d.Nanoseconds())
|
||||
}
|
||||
|
||||
// Set sets key to value in the r.customReplacements map.
|
||||
@@ -392,6 +541,7 @@ func (r *replacer) Set(key, value string) {
|
||||
|
||||
const (
|
||||
timeFormat = "02/Jan/2006:15:04:05 -0700"
|
||||
timeFormatISO = "2006-01-02T15:04:05" // ISO 8601 with timezone to be assumed as local
|
||||
timeFormatISOUTC = "2006-01-02T15:04:05Z" // ISO 8601 with timezone to be assumed as UTC
|
||||
headerContentType = "Content-Type"
|
||||
contentTypeJSON = "application/json"
|
||||
|
||||
@@ -16,12 +16,21 @@ package httpserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
)
|
||||
|
||||
func TestNewReplacer(t *testing.T) {
|
||||
@@ -53,7 +62,7 @@ func TestReplace(t *testing.T) {
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
|
||||
request, err := http.NewRequest("POST", "http://localhost.local/?foo=bar", reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to make request: %v", err)
|
||||
}
|
||||
@@ -77,7 +86,8 @@ func TestReplace(t *testing.T) {
|
||||
|
||||
old := now
|
||||
now = func() time.Time {
|
||||
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
|
||||
// Note that the `-7` is seconds, not hours.
|
||||
return time.Date(2006, 1, 2, 15, 4, 5, 99999999, time.FixedZone("hardcoded", -7))
|
||||
}
|
||||
defer func() {
|
||||
now = old
|
||||
@@ -87,17 +97,19 @@ func TestReplace(t *testing.T) {
|
||||
expect string
|
||||
}{
|
||||
{"This hostname is {hostname}", "This hostname is " + hostname},
|
||||
{"This host is {host}.", "This host is localhost."},
|
||||
{"This host is {host}.", "This host is localhost.local."},
|
||||
{"This request method is {method}.", "This request method is POST."},
|
||||
{"The response status is {status}.", "The response status is 200."},
|
||||
{"{when}", "02/Jan/2006:15:04:05 +0000"},
|
||||
{"{when_iso}", "2006-01-02T15:04:12Z"},
|
||||
{"{when_iso_local}", "2006-01-02T15:04:05"},
|
||||
{"{when_unix}", "1136214252"},
|
||||
{"{when_unix_ms}", "1136214252099"},
|
||||
{"The Custom header is {>Custom}.", "The Custom header is foobarbaz."},
|
||||
{"The CustomAdd header is {>CustomAdd}.", "The CustomAdd header is caddy."},
|
||||
{"The Custom response header is {<Custom}.", "The Custom response header is CustomResponseHeader."},
|
||||
{"Bad {>Custom placeholder", "Bad {>Custom placeholder"},
|
||||
{"The request is {request}.", "The request is POST /?foo=bar HTTP/1.1\\r\\nHost: localhost\\r\\n" +
|
||||
{"The request is {request}.", "The request is POST /?foo=bar HTTP/1.1\\r\\nHost: localhost.local\\r\\n" +
|
||||
"Cookie: foo=bar; taste=delicious\\r\\nCustom: foobarbaz\\r\\nCustomadd: caddy\\r\\n" +
|
||||
"Shorterval: 1\\r\\n\\r\\n."},
|
||||
{"The cUsToM header is {>cUsToM}...", "The cUsToM header is foobarbaz..."},
|
||||
@@ -112,6 +124,9 @@ func TestReplace(t *testing.T) {
|
||||
{"Query string is {query}", "Query string is foo=bar"},
|
||||
{"Query string value for foo is {?foo}", "Query string value for foo is bar"},
|
||||
{"Missing query string argument is {?missing}", "Missing query string argument is "},
|
||||
{"{label1} {label2} {label3} {label4}", "localhost local - -"},
|
||||
{"Label with missing number is {label} or {labelQQ}", "Label with missing number is - or -"},
|
||||
{"\\{ 'hostname': '{hostname}' \\}", "{ 'hostname': '" + hostname + "' }"},
|
||||
}
|
||||
|
||||
for _, c := range testCases {
|
||||
@@ -144,6 +159,168 @@ func TestReplace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTlsReplace(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
|
||||
clientCertText := []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIB9jCCAV+gAwIBAgIBAjANBgkqhkiG9w0BAQsFADAYMRYwFAYDVQQDDA1DYWRk
|
||||
eSBUZXN0IENBMB4XDTE4MDcyNDIxMzUwNVoXDTI4MDcyMTIxMzUwNVowHTEbMBkG
|
||||
A1UEAwwSY2xpZW50LmxvY2FsZG9tYWluMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
|
||||
iQKBgQDFDEpzF0ew68teT3xDzcUxVFaTII+jXH1ftHXxxP4BEYBU4q90qzeKFneF
|
||||
z83I0nC0WAQ45ZwHfhLMYHFzHPdxr6+jkvKPASf0J2v2HDJuTM1bHBbik5Ls5eq+
|
||||
fVZDP8o/VHKSBKxNs8Goc2NTsr5b07QTIpkRStQK+RJALk4x9QIDAQABo0swSTAJ
|
||||
BgNVHRMEAjAAMAsGA1UdDwQEAwIHgDAaBgNVHREEEzARgglsb2NhbGhvc3SHBH8A
|
||||
AAEwEwYDVR0lBAwwCgYIKwYBBQUHAwIwDQYJKoZIhvcNAQELBQADgYEANSjz2Sk+
|
||||
eqp31wM9il1n+guTNyxJd+FzVAH+hCZE5K+tCgVDdVFUlDEHHbS/wqb2PSIoouLV
|
||||
3Q9fgDkiUod+uIK0IynzIKvw+Cjg+3nx6NQ0IM0zo8c7v398RzB4apbXKZyeeqUH
|
||||
9fNwfEi+OoXR6s+upSKobCmLGLGi9Na5s5g=
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
block, _ := pem.Decode(clientCertText)
|
||||
if block == nil {
|
||||
t.Fatalf("failed to decode PEM certificate")
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decode PEM certificate: %v", err)
|
||||
}
|
||||
|
||||
request := &http.Request{
|
||||
Method: "GET",
|
||||
Host: "foo.com",
|
||||
URL: &url.URL{
|
||||
Scheme: "https",
|
||||
Path: "/path/",
|
||||
Host: "foo.com",
|
||||
},
|
||||
Header: http.Header{},
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
RemoteAddr: "192.0.2.1:1234",
|
||||
RequestURI: "https://foo.com/path/",
|
||||
TLS: &tls.ConnectionState{
|
||||
Version: tls.VersionTLS12,
|
||||
HandshakeComplete: true,
|
||||
ServerName: "foo.com",
|
||||
CipherSuite: tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
PeerCertificates: []*x509.Certificate{cert},
|
||||
},
|
||||
}
|
||||
|
||||
repl := NewReplacer(request, recordRequest, "-")
|
||||
|
||||
now := time.Now().In(time.UTC)
|
||||
days := int64(cert.NotAfter.Sub(now).Seconds() / 86400)
|
||||
pemBlock := pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}
|
||||
|
||||
protocol, _ := caddytls.GetSupportedProtocolName(request.TLS.Version)
|
||||
cipher, _ := caddytls.GetSupportedCipherName(request.TLS.CipherSuite)
|
||||
cEscapedCert := url.QueryEscape(string(pem.EncodeToMemory(&pemBlock)))
|
||||
cFingerprint := fmt.Sprintf("%x", sha256.Sum256(cert.Raw))
|
||||
cIDn := cert.Issuer.String()
|
||||
cRawCert := string(cert.Raw)
|
||||
cSDn := cert.Subject.String()
|
||||
cSerial := fmt.Sprintf("%x", cert.SerialNumber)
|
||||
cVEnd := cert.NotAfter.In(time.UTC).Format("Jan 02 15:04:05 2006 MST")
|
||||
cVRemain := strconv.FormatInt(days, 10)
|
||||
cVStart := cert.NotBefore.Format("Jan 02 15:04:05 2006 MST")
|
||||
|
||||
testCases := []struct {
|
||||
template string
|
||||
expect string
|
||||
}{
|
||||
{"{tls_protocol}", protocol},
|
||||
{"{tls_cipher}", cipher},
|
||||
{"{tls_client_escaped_cert}", cEscapedCert},
|
||||
{"{tls_client_fingerprint}", cFingerprint},
|
||||
{"{tls_client_i_dn}", cIDn},
|
||||
{"{tls_client_raw_cert}", cRawCert},
|
||||
{"{tls_client_s_dn}", cSDn},
|
||||
{"{tls_client_serial}", cSerial},
|
||||
{"{tls_client_v_end}", cVEnd},
|
||||
{"{tls_client_v_remain}", cVRemain},
|
||||
{"{tls_client_v_start}", cVStart},
|
||||
}
|
||||
|
||||
for _, c := range testCases {
|
||||
if expected, actual := c.expect, repl.Replace(c.template); expected != actual {
|
||||
t.Errorf("for template '%s', expected '%s', got '%s'", c.template, expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReplace(b *testing.B) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to make request: %v", err)
|
||||
}
|
||||
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, *request.URL)
|
||||
request = request.WithContext(ctx)
|
||||
|
||||
request.Header.Set("Custom", "foobarbaz")
|
||||
request.Header.Set("ShorterVal", "1")
|
||||
repl := NewReplacer(request, recordRequest, "-")
|
||||
// add some headers after creating replacer
|
||||
request.Header.Set("CustomAdd", "caddy")
|
||||
request.Header.Set("Cookie", "foo=bar; taste=delicious")
|
||||
|
||||
// add some respons headers
|
||||
recordRequest.Header().Set("Custom", "CustomResponseHeader")
|
||||
|
||||
now = func() time.Time {
|
||||
// Note that the `-7` is seconds, not hours.
|
||||
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
repl.Replace("This hostname is {hostname}")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReplaceEscaped(b *testing.B) {
|
||||
w := httptest.NewRecorder()
|
||||
recordRequest := NewResponseRecorder(w)
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
|
||||
request, err := http.NewRequest("POST", "http://localhost/?foo=bar", reader)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to make request: %v", err)
|
||||
}
|
||||
ctx := context.WithValue(request.Context(), OriginalURLCtxKey, *request.URL)
|
||||
request = request.WithContext(ctx)
|
||||
|
||||
request.Header.Set("Custom", "foobarbaz")
|
||||
request.Header.Set("ShorterVal", "1")
|
||||
repl := NewReplacer(request, recordRequest, "-")
|
||||
// add some headers after creating replacer
|
||||
request.Header.Set("CustomAdd", "caddy")
|
||||
request.Header.Set("Cookie", "foo=bar; taste=delicious")
|
||||
|
||||
// add some respons headers
|
||||
recordRequest.Header().Set("Custom", "CustomResponseHeader")
|
||||
|
||||
now = func() time.Time {
|
||||
// Note that the `-7` is seconds, not hours.
|
||||
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
repl.Replace("\\{ 'hostname': '{hostname}' \\}")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseRecorderNil(t *testing.T) {
|
||||
|
||||
reader := strings.NewReader(`{"username": "dennis"}`)
|
||||
@@ -161,6 +338,7 @@ func TestResponseRecorderNil(t *testing.T) {
|
||||
|
||||
old := now
|
||||
now = func() time.Time {
|
||||
// Note that the `-7` is seconds, not hours.
|
||||
return time.Date(2006, 1, 2, 15, 4, 5, 02, time.FixedZone("hardcoded", -7))
|
||||
}
|
||||
defer func() {
|
||||
|
||||
@@ -36,6 +36,7 @@ import (
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/caddyhttp/staticfiles"
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
)
|
||||
|
||||
// Server is the HTTP server implementation.
|
||||
@@ -273,16 +274,26 @@ func (s *Server) Listen() (net.Listener, error) {
|
||||
ln = tcpKeepAliveListener{TCPListener: tcpLn}
|
||||
}
|
||||
|
||||
cln := s.WrapListener(ln)
|
||||
|
||||
// Very important to return a concrete caddy.Listener
|
||||
// implementation for graceful restarts.
|
||||
return cln.(caddy.Listener), nil
|
||||
}
|
||||
|
||||
// WrapListener wraps ln in the listener middlewares configured
|
||||
// for this server.
|
||||
func (s *Server) WrapListener(ln net.Listener) net.Listener {
|
||||
if ln == nil {
|
||||
return nil
|
||||
}
|
||||
cln := ln.(caddy.Listener)
|
||||
for _, site := range s.sites {
|
||||
for _, m := range site.listenerMiddleware {
|
||||
cln = m(cln)
|
||||
}
|
||||
}
|
||||
|
||||
// Very important to return a concrete caddy.Listener
|
||||
// implementation for graceful restarts.
|
||||
return cln.(caddy.Listener), nil
|
||||
return cln
|
||||
}
|
||||
|
||||
// ListenPacket creates udp connection for QUIC if it is enabled,
|
||||
@@ -319,6 +330,9 @@ func (s *Server) Serve(ln net.Listener) error {
|
||||
}
|
||||
|
||||
err := s.Server.Serve(ln)
|
||||
if err == http.ErrServerClosed {
|
||||
err = nil // not an error worth reporting since closing a server is intentional
|
||||
}
|
||||
if s.quicServer != nil {
|
||||
s.quicServer.Close()
|
||||
}
|
||||
@@ -345,6 +359,16 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
|
||||
// record the User-Agent string (with a cap on its length to mitigate attacks)
|
||||
ua := r.Header.Get("User-Agent")
|
||||
if len(ua) > 512 {
|
||||
ua = ua[:512]
|
||||
}
|
||||
uaHash := telemetry.FastHash([]byte(ua)) // this is a normalized field
|
||||
go telemetry.SetNested("http_user_agent", uaHash, ua)
|
||||
go telemetry.AppendUnique("http_user_agent_count", uaHash)
|
||||
go telemetry.Increment("http_request_count")
|
||||
|
||||
// copy the original, unchanged URL into the context
|
||||
// so it can be referenced by middlewares
|
||||
urlCopy := *r.URL
|
||||
@@ -388,24 +412,26 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
||||
|
||||
if vhost == nil {
|
||||
// check for ACME challenge even if vhost is nil;
|
||||
// could be a new host coming online soon
|
||||
if caddytls.HTTPChallengeHandler(w, r, "localhost") {
|
||||
// could be a new host coming online soon - choose any
|
||||
// vhost's cert manager configuration, I guess
|
||||
if len(s.sites) > 0 && s.sites[0].TLS.Manager.HandleHTTPChallenge(w, r) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// otherwise, log the error and write a message to the client
|
||||
remoteHost, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
remoteHost = r.RemoteAddr
|
||||
}
|
||||
WriteSiteNotFound(w, r) // don't add headers outside of this function
|
||||
WriteSiteNotFound(w, r) // don't add headers outside of this function (http.forwardproxy)
|
||||
log.Printf("[INFO] %s - No such site at %s (Remote: %s, Referer: %s)",
|
||||
hostname, s.Server.Addr, remoteHost, r.Header.Get("Referer"))
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// we still check for ACME challenge if the vhost exists,
|
||||
// because we must apply its HTTP challenge config settings
|
||||
if caddytls.HTTPChallengeHandler(w, r, vhost.ListenHost) {
|
||||
// because the HTTP challenge might be disabled by its config
|
||||
if vhost.TLS.Manager.HandleHTTPChallenge(w, r) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -416,19 +442,39 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) (int, error)
|
||||
r.URL = trimPathPrefix(r.URL, pathPrefix)
|
||||
}
|
||||
|
||||
// enforce strict host matching, which ensures that the SNI
|
||||
// value (if any), matches the Host header; essential for
|
||||
// sites that rely on TLS ClientAuth sharing a port with
|
||||
// sites that do not - if mismatched, close the connection
|
||||
if vhost.StrictHostMatching && r.TLS != nil &&
|
||||
strings.ToLower(r.TLS.ServerName) != strings.ToLower(hostname) {
|
||||
r.Close = true
|
||||
log.Printf("[ERROR] %s - strict host matching: SNI (%s) and HTTP Host (%s) values differ",
|
||||
vhost.Addr, r.TLS.ServerName, hostname)
|
||||
return http.StatusForbidden, nil
|
||||
}
|
||||
|
||||
return vhost.middlewareChain.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func trimPathPrefix(u *url.URL, prefix string) *url.URL {
|
||||
// We need to use URL.EscapedPath() when trimming the pathPrefix as
|
||||
// URL.Path is ambiguous about / or %2f - see docs. See #1927
|
||||
trimmed := strings.TrimPrefix(u.EscapedPath(), prefix)
|
||||
if !strings.HasPrefix(trimmed, "/") {
|
||||
trimmed = "/" + trimmed
|
||||
trimmedPath := strings.TrimPrefix(u.EscapedPath(), prefix)
|
||||
if !strings.HasPrefix(trimmedPath, "/") {
|
||||
trimmedPath = "/" + trimmedPath
|
||||
}
|
||||
trimmedURL, err := url.Parse(trimmed)
|
||||
// After trimming path reconstruct uri string with Query before parsing
|
||||
trimmedURI := trimmedPath
|
||||
if u.RawQuery != "" || u.ForceQuery == true {
|
||||
trimmedURI = trimmedPath + "?" + u.RawQuery
|
||||
}
|
||||
if u.Fragment != "" {
|
||||
trimmedURI = trimmedURI + "#" + u.Fragment
|
||||
}
|
||||
trimmedURL, err := url.Parse(trimmedURI)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmed, err)
|
||||
log.Printf("[ERROR] Unable to parse trimmed URL %s: %v", trimmedURI, err)
|
||||
return u
|
||||
}
|
||||
return trimmedURL
|
||||
|
||||
@@ -129,88 +129,108 @@ func TestMakeHTTPServerWithTimeouts(t *testing.T) {
|
||||
|
||||
func TestTrimPathPrefix(t *testing.T) {
|
||||
for i, pt := range []struct {
|
||||
path string
|
||||
url string
|
||||
prefix string
|
||||
expected string
|
||||
shouldFail bool
|
||||
}{
|
||||
{
|
||||
path: "/my/path",
|
||||
url: "/my/path",
|
||||
prefix: "/my",
|
||||
expected: "/path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/%2f/path",
|
||||
url: "/my/%2f/path",
|
||||
prefix: "/my",
|
||||
expected: "/%2f/path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/path",
|
||||
url: "/my/path",
|
||||
prefix: "/my/",
|
||||
expected: "/path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my///path",
|
||||
url: "/my///path",
|
||||
prefix: "/my",
|
||||
expected: "/path",
|
||||
shouldFail: true,
|
||||
},
|
||||
{
|
||||
path: "/my///path",
|
||||
url: "/my///path",
|
||||
prefix: "/my",
|
||||
expected: "///path",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/path///slash",
|
||||
url: "/my/path///slash",
|
||||
prefix: "/my",
|
||||
expected: "/path///slash",
|
||||
shouldFail: false,
|
||||
},
|
||||
{
|
||||
path: "/my/%2f/path/%2f",
|
||||
url: "/my/%2f/path/%2f",
|
||||
prefix: "/my",
|
||||
expected: "/%2f/path/%2f",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/my/%20/path",
|
||||
url: "/my/%20/path",
|
||||
prefix: "/my",
|
||||
expected: "/%20/path",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/path",
|
||||
url: "/path",
|
||||
prefix: "",
|
||||
expected: "/path",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/path/my/",
|
||||
url: "/path/my/",
|
||||
prefix: "/my",
|
||||
expected: "/path/my/",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "",
|
||||
url: "",
|
||||
prefix: "/my",
|
||||
expected: "/",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
path: "/apath",
|
||||
url: "/apath",
|
||||
prefix: "",
|
||||
expected: "/apath",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
url: "/my/path/page.php?akey=value",
|
||||
prefix: "/my",
|
||||
expected: "/path/page.php?akey=value",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
url: "/my/path/page?key=value#fragment",
|
||||
prefix: "/my",
|
||||
expected: "/path/page?key=value#fragment",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
url: "/my/path/page#fragment",
|
||||
prefix: "/my",
|
||||
expected: "/path/page#fragment",
|
||||
shouldFail: false,
|
||||
}, {
|
||||
url: "/my/apath?",
|
||||
prefix: "/my",
|
||||
expected: "/apath?",
|
||||
shouldFail: false,
|
||||
},
|
||||
} {
|
||||
|
||||
u, _ := url.Parse(pt.path)
|
||||
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.EscapedPath() != want {
|
||||
u, _ := url.Parse(pt.url)
|
||||
if got, want := trimPathPrefix(u, pt.prefix), pt.expected; got.String() != want {
|
||||
if !pt.shouldFail {
|
||||
|
||||
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.EscapedPath())
|
||||
t.Errorf("Test %d: Expected='%s', but was '%s' ", i, want, got.String())
|
||||
}
|
||||
} else if pt.shouldFail {
|
||||
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.EscapedPath())
|
||||
t.Errorf("SHOULDFAIL Test %d: Expected='%s', and was '%s' but should fail", i, want, got.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -36,6 +36,16 @@ type SiteConfig struct {
|
||||
// TLS configuration
|
||||
TLS *caddytls.Config
|
||||
|
||||
// If true, the Host header in the HTTP request must
|
||||
// match the SNI value in the TLS handshake (if any).
|
||||
// This should be enabled whenever a site relies on
|
||||
// TLS client authentication, for example; or any time
|
||||
// you want to enforce that THIS site's TLS config
|
||||
// is used and not the TLS config of any other site
|
||||
// on the same listener. TODO: Check how relevant this
|
||||
// is with TLS 1.3.
|
||||
StrictHostMatching bool
|
||||
|
||||
// Uncompiled middleware stack
|
||||
middleware []Middleware
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
|
||||
"os"
|
||||
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
"github.com/russross/blackfriday"
|
||||
)
|
||||
|
||||
@@ -448,6 +449,15 @@ func (c Context) AddLink(link string) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Returns either TLS protocol version if TLS used or empty string otherwise
|
||||
func (c Context) TLSVersion() (ret string) {
|
||||
if c.Req.TLS != nil {
|
||||
// Safe to ignore an error
|
||||
ret, _ = caddytls.GetSupportedProtocolName(c.Req.TLS.Version)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// buffer pool for .Include context actions
|
||||
var includeBufs = sync.Pool{
|
||||
New: func() interface{} {
|
||||
|
||||
@@ -16,6 +16,7 @@ package httpserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
@@ -277,7 +278,7 @@ func TestHostname(t *testing.T) {
|
||||
// // Test 3 - ipv6 without port and brackets
|
||||
// {"2001:4860:4860::8888", "google-public-dns-a.google.com."},
|
||||
// Test 4 - no hostname available
|
||||
{"1.1.1.1", "1.1.1.1"},
|
||||
{"0.0.0.0", "0.0.0.0"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
@@ -922,3 +923,40 @@ func TestAddLink(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTlsVersion(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
tlsState *tls.ConnectionState
|
||||
expectedResult string
|
||||
}{
|
||||
{
|
||||
&tls.ConnectionState{Version: tls.VersionTLS10},
|
||||
"tls1.0",
|
||||
},
|
||||
{
|
||||
&tls.ConnectionState{Version: tls.VersionTLS11},
|
||||
"tls1.1",
|
||||
},
|
||||
{
|
||||
&tls.ConnectionState{Version: tls.VersionTLS12},
|
||||
"tls1.2",
|
||||
},
|
||||
// TLS not used
|
||||
{
|
||||
nil,
|
||||
"",
|
||||
},
|
||||
// Unsupported version
|
||||
{
|
||||
&tls.ConnectionState{Version: 0x0399},
|
||||
"",
|
||||
},
|
||||
} {
|
||||
context := getContextOrFail(t)
|
||||
context.Req.TLS = test.tlsState
|
||||
result := context.TLSVersion()
|
||||
if result != test.expectedResult {
|
||||
t.Errorf("Expected %s got %s", test.expectedResult, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,6 +67,10 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
|
||||
// Write log entries
|
||||
for _, e := range rule.Entries {
|
||||
// Check if there is an exception to prevent log being written
|
||||
if !e.Log.ShouldLog(r.URL.Path) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Mask IP Address
|
||||
if e.Log.IPMaskExists {
|
||||
@@ -78,6 +82,7 @@ func (l Logger) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
}
|
||||
}
|
||||
e.Log.Println(rep.Replace(e.Format))
|
||||
|
||||
}
|
||||
|
||||
return status, err
|
||||
|
||||
@@ -177,3 +177,85 @@ func TestMultiEntries(t *testing.T) {
|
||||
t.Errorf("Expected %q, but got %q", expect, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogExcept(t *testing.T) {
|
||||
tests := []struct {
|
||||
LogRules []Rule
|
||||
logPath string
|
||||
shouldLog bool
|
||||
}{
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/soup"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
}}, `/soup`, false},
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/tart"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
}}, `/soup`, true},
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/soup"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
}}, `/tomatosoup`, true},
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/pie/"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
// Check exception with a trailing slash does not match without
|
||||
}}, `/pie`, true},
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/pie.php"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
}}, `/pie`, true},
|
||||
{[]Rule{{
|
||||
PathScope: "/",
|
||||
Entries: []*Entry{{
|
||||
Log: &httpserver.Logger{
|
||||
|
||||
Exceptions: []string{"/pie"},
|
||||
},
|
||||
Format: DefaultLogFormat,
|
||||
}},
|
||||
// Check that a word without trailing slash will match a filename
|
||||
}}, `/pie.php`, false},
|
||||
}
|
||||
for i, test := range tests {
|
||||
for _, LogRule := range test.LogRules {
|
||||
for _, e := range LogRule.Entries {
|
||||
shouldLog := e.Log.ShouldLog(test.logPath)
|
||||
if shouldLog != test.shouldLog {
|
||||
t.Fatalf("Test %d expected shouldLog=%t but got shouldLog=%t,", i, test.shouldLog, shouldLog)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ func setup(c *caddy.Controller) error {
|
||||
|
||||
func logParse(c *caddy.Controller) ([]*Rule, error) {
|
||||
var rules []*Rule
|
||||
|
||||
var logExceptions []string
|
||||
for c.Next() {
|
||||
args := c.RemainingArgs()
|
||||
|
||||
@@ -91,6 +91,12 @@ func logParse(c *caddy.Controller) ([]*Rule, error) {
|
||||
|
||||
}
|
||||
|
||||
} else if what == "except" {
|
||||
|
||||
for i := 0; i < len(where); i++ {
|
||||
logExceptions = append(logExceptions, where[i])
|
||||
}
|
||||
|
||||
} else if httpserver.IsLogRollerSubdirective(what) {
|
||||
|
||||
if err := httpserver.ParseRoller(logRoller, what, where...); err != nil {
|
||||
@@ -133,6 +139,7 @@ func logParse(c *caddy.Controller) ([]*Rule, error) {
|
||||
V4ipMask: ip4Mask,
|
||||
V6ipMask: ip6Mask,
|
||||
IPMaskExists: ipMaskExists,
|
||||
Exceptions: logExceptions,
|
||||
},
|
||||
Format: format,
|
||||
})
|
||||
|
||||
@@ -142,7 +142,7 @@ func (md Markdown) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error
|
||||
case err == nil: // nop
|
||||
case os.IsPermission(err):
|
||||
return http.StatusForbidden, err
|
||||
case os.IsExist(err):
|
||||
case os.IsNotExist(err):
|
||||
return http.StatusNotFound, nil
|
||||
default: // did we run out of FD?
|
||||
return http.StatusInternalServerError, err
|
||||
|
||||
@@ -183,6 +183,8 @@ type Header struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
var roundRobinPolicier RoundRobin
|
||||
|
||||
// Select selects the host based on hashing the header value
|
||||
func (r *Header) Select(pool HostPool, request *http.Request) *UpstreamHost {
|
||||
if r.Name == "" {
|
||||
@@ -190,7 +192,8 @@ func (r *Header) Select(pool HostPool, request *http.Request) *UpstreamHost {
|
||||
}
|
||||
val := request.Header.Get(r.Name)
|
||||
if val == "" {
|
||||
return nil
|
||||
// fallback to RoundRobin policy in case no Header in request
|
||||
return roundRobinPolicier.Select(pool, request)
|
||||
}
|
||||
return hostByHashing(pool, val)
|
||||
}
|
||||
|
||||
@@ -320,21 +320,25 @@ func TestUriPolicy(t *testing.T) {
|
||||
func TestHeaderPolicy(t *testing.T) {
|
||||
pool := testPool()
|
||||
tests := []struct {
|
||||
Name string
|
||||
Policy *Header
|
||||
RequestHeaderName string
|
||||
RequestHeaderValue string
|
||||
NilHost bool
|
||||
HostIndex int
|
||||
}{
|
||||
{&Header{""}, "", "", true, 0},
|
||||
{&Header{""}, "Affinity", "somevalue", true, 0},
|
||||
{&Header{""}, "Affinity", "", true, 0},
|
||||
{"empty config", &Header{""}, "", "", true, 0},
|
||||
{"empty config+header+value", &Header{""}, "Affinity", "somevalue", true, 0},
|
||||
{"empty config+header", &Header{""}, "Affinity", "", true, 0},
|
||||
|
||||
{&Header{"Affinity"}, "", "", true, 0},
|
||||
{&Header{"Affinity"}, "Affinity", "somevalue", false, 1},
|
||||
{&Header{"Affinity"}, "Affinity", "somevalue2", false, 0},
|
||||
{&Header{"Affinity"}, "Affinity", "somevalue3", false, 2},
|
||||
{&Header{"Affinity"}, "Affinity", "", true, 0},
|
||||
{"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 1},
|
||||
{"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 2},
|
||||
{"no header(fallback to roundrobin)", &Header{"Affinity"}, "", "", false, 0},
|
||||
|
||||
{"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue", false, 1},
|
||||
{"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue2", false, 0},
|
||||
{"hash route to host", &Header{"Affinity"}, "Affinity", "somevalue3", false, 2},
|
||||
{"hash route with empty value", &Header{"Affinity"}, "Affinity", "", false, 1},
|
||||
}
|
||||
|
||||
for idx, test := range tests {
|
||||
|
||||
@@ -47,6 +47,12 @@ type Upstream interface {
|
||||
// Checks if subpath is not an ignored path
|
||||
AllowedPath(string) bool
|
||||
|
||||
// Gets the duration of the headstart the first
|
||||
// connection is given in the Go standard library's
|
||||
// implementation of "Happy Eyeballs" when DualStack
|
||||
// is enabled in net.Dialer.
|
||||
GetFallbackDelay() time.Duration
|
||||
|
||||
// Gets how long to try selecting upstream hosts
|
||||
// in the case of cascading failures.
|
||||
GetTryDuration() time.Duration
|
||||
@@ -58,6 +64,10 @@ type Upstream interface {
|
||||
// Gets the number of upstream hosts.
|
||||
GetHostCount() int
|
||||
|
||||
// Gets how long to wait before timing out
|
||||
// the request
|
||||
GetTimeout() time.Duration
|
||||
|
||||
// Stops the upstream from proxying requests to shutdown goroutines cleanly.
|
||||
Stop() error
|
||||
}
|
||||
@@ -187,7 +197,12 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
if nameURL, err := url.Parse(host.Name); err == nil {
|
||||
outreq.Host = nameURL.Host
|
||||
if proxy == nil {
|
||||
proxy = NewSingleHostReverseProxy(nameURL, host.WithoutPathPrefix, http.DefaultMaxIdleConnsPerHost)
|
||||
proxy = NewSingleHostReverseProxy(nameURL,
|
||||
host.WithoutPathPrefix,
|
||||
http.DefaultMaxIdleConnsPerHost,
|
||||
upstream.GetTimeout(),
|
||||
upstream.GetFallbackDelay(),
|
||||
)
|
||||
}
|
||||
|
||||
// use upstream credentials by default
|
||||
@@ -247,6 +262,10 @@ func (p Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
return http.StatusRequestEntityTooLarge, backendErr
|
||||
}
|
||||
|
||||
if backendErr == context.Canceled {
|
||||
return CustomStatusContextCancelled, backendErr
|
||||
}
|
||||
|
||||
// failover; remember this failure for some time if
|
||||
// request failure counting is enabled
|
||||
timeout := host.FailTimeout
|
||||
@@ -382,3 +401,5 @@ func mutateHeadersByRules(headers, rules http.Header, repl httpserver.Replacer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const CustomStatusContextCancelled = 499
|
||||
|
||||
+147
-38
@@ -122,7 +122,7 @@ func TestReverseProxy(t *testing.T) {
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second, 300*time.Millisecond)},
|
||||
}
|
||||
|
||||
// Create the fake request body.
|
||||
@@ -202,7 +202,7 @@ func TestReverseProxyInsecureSkipVerify(t *testing.T) {
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, true)},
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, true, 30*time.Second, 300*time.Millisecond)},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
@@ -287,6 +287,32 @@ func TestReverseProxyMaxConnLimit(t *testing.T) {
|
||||
jobs.Wait()
|
||||
}
|
||||
|
||||
func TestReverseProxyTimeout(t *testing.T) {
|
||||
timeout := 2 * time.Second
|
||||
fallbackDelay := 300 * time.Millisecond
|
||||
errorMargin := 100 * time.Millisecond
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{newFakeUpstream("https://8.8.8.8", true, timeout, fallbackDelay)},
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
start := time.Now()
|
||||
p.ServeHTTP(w, r)
|
||||
took := time.Since(start)
|
||||
|
||||
if took > timeout+errorMargin {
|
||||
t.Errorf("Expected timeout ~ %v but got %v", timeout, took)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
|
||||
// Capture the expected panic
|
||||
defer func() {
|
||||
@@ -301,7 +327,7 @@ func TestWebSocketReverseProxyNonHijackerPanic(t *testing.T) {
|
||||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
|
||||
|
||||
// Create client request
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
@@ -331,7 +357,7 @@ func TestWebSocketReverseProxyBackendShutDown(t *testing.T) {
|
||||
}()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(backend.URL, false)
|
||||
p := newWebSocketTestProxy(backend.URL, false, 30*time.Second)
|
||||
backendProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
}))
|
||||
@@ -360,7 +386,7 @@ func TestWebSocketReverseProxyServeHTTPHandler(t *testing.T) {
|
||||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
|
||||
|
||||
// Create client request
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
@@ -407,7 +433,7 @@ func TestWebSocketReverseProxyFromWSClient(t *testing.T) {
|
||||
defer wsEcho.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsEcho.URL, false)
|
||||
p := newWebSocketTestProxy(wsEcho.URL, false, 30*time.Second)
|
||||
|
||||
// This is a full end-end test, so the proxy handler
|
||||
// has to be part of a server listening on a port. Our
|
||||
@@ -452,7 +478,7 @@ func TestWebSocketReverseProxyFromWSSClient(t *testing.T) {
|
||||
}))
|
||||
defer wsEcho.Close()
|
||||
|
||||
p := newWebSocketTestProxy(wsEcho.URL, true)
|
||||
p := newWebSocketTestProxy(wsEcho.URL, true, 30*time.Second)
|
||||
|
||||
echoProxy := newTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
@@ -528,7 +554,7 @@ func TestUnixSocketProxy(t *testing.T) {
|
||||
defer ts.Close()
|
||||
|
||||
url := strings.Replace(ts.URL, "http://", "unix:", 1)
|
||||
p := newWebSocketTestProxy(url, false)
|
||||
p := newWebSocketTestProxy(url, false, 30*time.Second)
|
||||
|
||||
echoProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
p.ServeHTTP(w, r)
|
||||
@@ -686,7 +712,7 @@ func TestUpstreamHeadersUpdate(t *testing.T) {
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream := newFakeUpstream(backend.URL, false, 30*time.Second, 300*time.Millisecond)
|
||||
upstream.host.UpstreamHeaders = http.Header{
|
||||
"Connection": {"{>Connection}"},
|
||||
"Upgrade": {"{>Upgrade}"},
|
||||
@@ -753,7 +779,7 @@ func TestDownstreamHeadersUpdate(t *testing.T) {
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream := newFakeUpstream(backend.URL, false, 30*time.Second, 300*time.Millisecond)
|
||||
upstream.host.DownstreamHeaders = http.Header{
|
||||
"+Merge-Me": {"Merge-Value"},
|
||||
"+Add-Me": {"Add-Value"},
|
||||
@@ -893,7 +919,7 @@ func TestHostSimpleProxyNoHeaderForward(t *testing.T) {
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second, 300*time.Millisecond)},
|
||||
}
|
||||
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
@@ -913,6 +939,67 @@ func TestHostSimpleProxyNoHeaderForward(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxyTransparentHeaders(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
forwardedForHeader string
|
||||
expected []string
|
||||
}{
|
||||
{"No header", "192.168.0.1:80", "", []string{"192.168.0.1"}},
|
||||
{"Existing", "192.168.0.1:80", "1.1.1.1, 2.2.2.2", []string{"1.1.1.1, 2.2.2.2, 192.168.0.1"}},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
testReverseProxyTransparentHeaders(t, tc.remoteAddr, tc.forwardedForHeader, tc.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testReverseProxyTransparentHeaders(t *testing.T, remoteAddr, forwardedForHeader string, expected []string) {
|
||||
// Arrange
|
||||
log.SetOutput(ioutil.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
var actualHeaders http.Header
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actualHeaders = r.Header
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
config := "proxy / " + backend.URL + " {\n transparent \n}"
|
||||
|
||||
// make proxy
|
||||
upstreams, err := NewStaticUpstreams(caddyfile.NewDispenser("Testfile", strings.NewReader(config)), "")
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error. Got: %s", err.Error())
|
||||
}
|
||||
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: upstreams,
|
||||
}
|
||||
|
||||
// create request and response recorder
|
||||
r := httptest.NewRequest("GET", backend.URL, nil)
|
||||
r.RemoteAddr = remoteAddr
|
||||
if forwardedForHeader != "" {
|
||||
r.Header.Set("X-Forwarded-For", forwardedForHeader)
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Act
|
||||
p.ServeHTTP(w, r)
|
||||
|
||||
// Assert
|
||||
if got := actualHeaders["X-Forwarded-For"]; !reflect.DeepEqual(got, expected) {
|
||||
t.Errorf("Transparent proxy response does not contain expected %v header: expect %v, but got %v",
|
||||
"X-Forwarded-For", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHostHeaderReplacedUsingForward(t *testing.T) {
|
||||
var requestHost string
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -921,7 +1008,7 @@ func TestHostHeaderReplacedUsingForward(t *testing.T) {
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream := newFakeUpstream(backend.URL, false, 30*time.Second, 300*time.Millisecond)
|
||||
proxyHostHeader := "test2.com"
|
||||
upstream.host.UpstreamHeaders = http.Header{"Host": []string{proxyHostHeader}}
|
||||
// set up proxy
|
||||
@@ -943,11 +1030,22 @@ func TestHostHeaderReplacedUsingForward(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
basicAuthTestcase(t, nil, nil)
|
||||
basicAuthTestcase(t, nil, url.UserPassword("username", "password"))
|
||||
basicAuthTestcase(t, url.UserPassword("usename", "password"), nil)
|
||||
basicAuthTestcase(t, url.UserPassword("unused", "unused"),
|
||||
url.UserPassword("username", "password"))
|
||||
testCases := []struct {
|
||||
name string
|
||||
upstreamUser *url.Userinfo
|
||||
clientUser *url.Userinfo
|
||||
}{
|
||||
{"Nil Both", nil, nil},
|
||||
{"Nil Upstream User", nil, url.UserPassword("username", "password")},
|
||||
{"Nil Client User", url.UserPassword("usename", "password"), nil},
|
||||
{"Both Provided", url.UserPassword("unused", "unused"),
|
||||
url.UserPassword("username", "password")},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
basicAuthTestcase(t, tc.upstreamUser, tc.clientUser)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func basicAuthTestcase(t *testing.T, upstreamUser, clientUser *url.Userinfo) {
|
||||
@@ -972,7 +1070,7 @@ func basicAuthTestcase(t *testing.T, upstreamUser, clientUser *url.Userinfo) {
|
||||
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext,
|
||||
Upstreams: []Upstream{newFakeUpstream(backURL.String(), false)},
|
||||
Upstreams: []Upstream{newFakeUpstream(backURL.String(), false, 30*time.Second, 300*time.Millisecond)},
|
||||
}
|
||||
r, err := http.NewRequest("GET", "/foo", nil)
|
||||
if err != nil {
|
||||
@@ -1107,7 +1205,7 @@ func TestProxyDirectorURL(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
|
||||
NewSingleHostReverseProxy(targetURL, c.without, 0).Director(req)
|
||||
NewSingleHostReverseProxy(targetURL, c.without, 0, 30*time.Second, 300*time.Millisecond).Director(req)
|
||||
if expect, got := c.expectURL, req.URL.String(); expect != got {
|
||||
t.Errorf("case %d url not equal: expect %q, but got %q",
|
||||
i, expect, got)
|
||||
@@ -1254,7 +1352,7 @@ func TestCancelRequest(t *testing.T) {
|
||||
// set up proxy
|
||||
p := &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false)},
|
||||
Upstreams: []Upstream{newFakeUpstream(backend.URL, false, 30*time.Second, 300*time.Millisecond)},
|
||||
}
|
||||
|
||||
// setup request with cancel ctx
|
||||
@@ -1271,7 +1369,7 @@ func TestCancelRequest(t *testing.T) {
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
status, err := p.ServeHTTP(rec, req)
|
||||
expectedStatus, expectErr := http.StatusBadGateway, context.Canceled
|
||||
expectedStatus, expectErr := CustomStatusContextCancelled, context.Canceled
|
||||
if status != expectedStatus || err != expectErr {
|
||||
t.Errorf("expect proxy handle return status[%d] with error[%v], but got status[%d] with error[%v]",
|
||||
expectedStatus, expectErr, status, err)
|
||||
@@ -1303,14 +1401,16 @@ func (r *noopReader) Read(b []byte) (int, error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func newFakeUpstream(name string, insecure bool) *fakeUpstream {
|
||||
func newFakeUpstream(name string, insecure bool, timeout, fallbackDelay time.Duration) *fakeUpstream {
|
||||
uri, _ := url.Parse(name)
|
||||
u := &fakeUpstream{
|
||||
name: name,
|
||||
from: "/",
|
||||
name: name,
|
||||
from: "/",
|
||||
timeout: timeout,
|
||||
fallbackDelay: fallbackDelay,
|
||||
host: &UpstreamHost{
|
||||
Name: name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost),
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, "", http.DefaultMaxIdleConnsPerHost, timeout, fallbackDelay),
|
||||
},
|
||||
}
|
||||
if insecure {
|
||||
@@ -1320,10 +1420,12 @@ func newFakeUpstream(name string, insecure bool) *fakeUpstream {
|
||||
}
|
||||
|
||||
type fakeUpstream struct {
|
||||
name string
|
||||
host *UpstreamHost
|
||||
from string
|
||||
without string
|
||||
name string
|
||||
host *UpstreamHost
|
||||
from string
|
||||
without string
|
||||
timeout time.Duration
|
||||
fallbackDelay time.Duration
|
||||
}
|
||||
|
||||
func (u *fakeUpstream) From() string {
|
||||
@@ -1338,15 +1440,17 @@ func (u *fakeUpstream) Select(r *http.Request) *UpstreamHost {
|
||||
}
|
||||
u.host = &UpstreamHost{
|
||||
Name: u.name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout(), u.GetFallbackDelay()),
|
||||
}
|
||||
}
|
||||
return u.host
|
||||
}
|
||||
|
||||
func (u *fakeUpstream) AllowedPath(requestPath string) bool { return true }
|
||||
func (u *fakeUpstream) GetFallbackDelay() time.Duration { return 300 * time.Millisecond }
|
||||
func (u *fakeUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
|
||||
func (u *fakeUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
|
||||
func (u *fakeUpstream) GetTimeout() time.Duration { return u.timeout }
|
||||
func (u *fakeUpstream) GetHostCount() int { return 1 }
|
||||
func (u *fakeUpstream) Stop() error { return nil }
|
||||
|
||||
@@ -1354,13 +1458,14 @@ func (u *fakeUpstream) Stop() error { return nil }
|
||||
// redirect to the specified backendAddr. The function
|
||||
// also sets up the rules/environment for testing WebSocket
|
||||
// proxy.
|
||||
func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
|
||||
func newWebSocketTestProxy(backendAddr string, insecure bool, timeout time.Duration) *Proxy {
|
||||
return &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{&fakeWsUpstream{
|
||||
name: backendAddr,
|
||||
without: "",
|
||||
insecure: insecure,
|
||||
timeout: timeout,
|
||||
}},
|
||||
}
|
||||
}
|
||||
@@ -1368,14 +1473,16 @@ func newWebSocketTestProxy(backendAddr string, insecure bool) *Proxy {
|
||||
func newPrefixedWebSocketTestProxy(backendAddr string, prefix string) *Proxy {
|
||||
return &Proxy{
|
||||
Next: httpserver.EmptyNext, // prevents panic in some cases when test fails
|
||||
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix}},
|
||||
Upstreams: []Upstream{&fakeWsUpstream{name: backendAddr, without: prefix, timeout: 30 * time.Second}},
|
||||
}
|
||||
}
|
||||
|
||||
type fakeWsUpstream struct {
|
||||
name string
|
||||
without string
|
||||
insecure bool
|
||||
name string
|
||||
without string
|
||||
insecure bool
|
||||
timeout time.Duration
|
||||
fallbackDelay time.Duration
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) From() string {
|
||||
@@ -1386,7 +1493,7 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
||||
uri, _ := url.Parse(u.name)
|
||||
host := &UpstreamHost{
|
||||
Name: u.name,
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost),
|
||||
ReverseProxy: NewSingleHostReverseProxy(uri, u.without, http.DefaultMaxIdleConnsPerHost, u.GetTimeout(), u.GetFallbackDelay()),
|
||||
UpstreamHeaders: http.Header{
|
||||
"Connection": {"{>Connection}"},
|
||||
"Upgrade": {"{>Upgrade}"}},
|
||||
@@ -1398,8 +1505,10 @@ func (u *fakeWsUpstream) Select(r *http.Request) *UpstreamHost {
|
||||
}
|
||||
|
||||
func (u *fakeWsUpstream) AllowedPath(requestPath string) bool { return true }
|
||||
func (u *fakeWsUpstream) GetFallbackDelay() time.Duration { return 300 * time.Millisecond }
|
||||
func (u *fakeWsUpstream) GetTryDuration() time.Duration { return 1 * time.Second }
|
||||
func (u *fakeWsUpstream) GetTryInterval() time.Duration { return 250 * time.Millisecond }
|
||||
func (u *fakeWsUpstream) GetTimeout() time.Duration { return u.timeout }
|
||||
func (u *fakeWsUpstream) GetHostCount() int { return 1 }
|
||||
func (u *fakeWsUpstream) Stop() error { return nil }
|
||||
|
||||
@@ -1445,7 +1554,7 @@ func BenchmarkProxy(b *testing.B) {
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
upstream := newFakeUpstream(backend.URL, false)
|
||||
upstream := newFakeUpstream(backend.URL, false, 30*time.Second, 300*time.Millisecond)
|
||||
upstream.host.UpstreamHeaders = http.Header{
|
||||
"Hostname": {"{hostname}"},
|
||||
"Host": {"{host}"},
|
||||
@@ -1488,7 +1597,7 @@ func TestChunkedWebSocketReverseProxy(t *testing.T) {
|
||||
defer wsNop.Close()
|
||||
|
||||
// Get proxy to use for the test
|
||||
p := newWebSocketTestProxy(wsNop.URL, false)
|
||||
p := newWebSocketTestProxy(wsNop.URL, false, 30*time.Second)
|
||||
|
||||
// Create client request
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
|
||||
@@ -48,6 +48,7 @@ var (
|
||||
defaultDialer = &net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}
|
||||
|
||||
bufferPool = sync.Pool{New: createBuffer}
|
||||
@@ -85,7 +86,6 @@ type ReverseProxy struct {
|
||||
Director func(*http.Request)
|
||||
|
||||
// The transport used to perform proxy requests.
|
||||
// If nil, http.DefaultTransport is used.
|
||||
Transport http.RoundTripper
|
||||
|
||||
// FlushInterval specifies the flush interval
|
||||
@@ -94,6 +94,10 @@ type ReverseProxy struct {
|
||||
// If zero, no periodic flushing is done.
|
||||
FlushInterval time.Duration
|
||||
|
||||
// dialer is used when values from the
|
||||
// defaultDialer need to be overridden per Proxy
|
||||
dialer *net.Dialer
|
||||
|
||||
srvResolver srvResolver
|
||||
}
|
||||
|
||||
@@ -103,13 +107,13 @@ type ReverseProxy struct {
|
||||
// What we need is just the path, so if "unix:/var/run/www.socket"
|
||||
// was the proxy directive, the parsed hostName would be
|
||||
// "unix:///var/run/www.socket", hence the ambiguous trimming.
|
||||
func socketDial(hostName string) func(network, addr string) (conn net.Conn, err error) {
|
||||
func socketDial(hostName string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
|
||||
return func(network, addr string) (conn net.Conn, err error) {
|
||||
return net.Dial("unix", hostName[len("unix://"):])
|
||||
return net.DialTimeout("unix", hostName[len("unix://"):], timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) {
|
||||
func (rp *ReverseProxy) srvDialerFunc(locator string, timeout time.Duration) func(network, addr string) (conn net.Conn, err error) {
|
||||
service := locator
|
||||
if strings.HasPrefix(locator, "srv://") {
|
||||
service = locator[6:]
|
||||
@@ -122,7 +126,7 @@ func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port))
|
||||
return net.DialTimeout("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port), timeout)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,7 +148,7 @@ func singleJoiningSlash(a, b string) string {
|
||||
// the target request will be for /base/dir.
|
||||
// Without logic: target's path is "/", incoming is "/api/messages",
|
||||
// without is "/api", then the target request will be for /messages.
|
||||
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *ReverseProxy {
|
||||
func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int, timeout, fallbackDelay time.Duration) *ReverseProxy {
|
||||
targetQuery := target.RawQuery
|
||||
director := func(req *http.Request) {
|
||||
if target.Scheme == "unix" {
|
||||
@@ -226,15 +230,24 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||
}
|
||||
}
|
||||
|
||||
dialer := *defaultDialer
|
||||
if timeout != defaultDialer.Timeout {
|
||||
dialer.Timeout = timeout
|
||||
}
|
||||
if fallbackDelay != defaultDialer.FallbackDelay {
|
||||
dialer.FallbackDelay = fallbackDelay
|
||||
}
|
||||
|
||||
rp := &ReverseProxy{
|
||||
Director: director,
|
||||
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
|
||||
srvResolver: net.DefaultResolver,
|
||||
dialer: &dialer,
|
||||
}
|
||||
|
||||
if target.Scheme == "unix" {
|
||||
rp.Transport = &http.Transport{
|
||||
Dial: socketDial(target.String()),
|
||||
Dial: socketDial(target.String(), timeout),
|
||||
}
|
||||
} else if target.Scheme == "quic" {
|
||||
rp.Transport = &h2quic.RoundTripper{
|
||||
@@ -244,9 +257,9 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||
},
|
||||
}
|
||||
} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
|
||||
dialFunc := defaultDialer.Dial
|
||||
dialFunc := rp.dialer.Dial
|
||||
if strings.HasPrefix(target.Scheme, "srv") {
|
||||
dialFunc = rp.srvDialerFunc(target.String())
|
||||
dialFunc = rp.srvDialerFunc(target.String(), timeout)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
@@ -264,6 +277,15 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||
http2.ConfigureTransport(transport)
|
||||
}
|
||||
rp.Transport = transport
|
||||
} else {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: rp.dialer.Dial,
|
||||
}
|
||||
if httpserver.HTTP2 {
|
||||
http2.ConfigureTransport(transport)
|
||||
}
|
||||
rp.Transport = transport
|
||||
}
|
||||
return rp
|
||||
}
|
||||
@@ -272,18 +294,7 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
|
||||
// when it is OK for upstream to be using a bad certificate,
|
||||
// since this transport skips verification.
|
||||
func (rp *ReverseProxy) UseInsecureTransport() {
|
||||
if rp.Transport == nil {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: defaultDialer.Dial,
|
||||
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
if httpserver.HTTP2 {
|
||||
http2.ConfigureTransport(transport)
|
||||
}
|
||||
rp.Transport = transport
|
||||
} else if transport, ok := rp.Transport.(*http.Transport); ok {
|
||||
if transport, ok := rp.Transport.(*http.Transport); ok {
|
||||
if transport.TLSClientConfig == nil {
|
||||
transport.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
@@ -305,8 +316,6 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
||||
transport := rp.Transport
|
||||
if requestIsWebsocket(outreq) {
|
||||
transport = newConnHijackerTransport(transport)
|
||||
} else if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
}
|
||||
|
||||
rp.Director(outreq)
|
||||
@@ -361,7 +370,7 @@ func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, outreq *http.Request,
|
||||
}
|
||||
bufferPool.Put(hj.Replay)
|
||||
} else {
|
||||
backendConn, err = net.Dial("tcp", outreq.URL.Host)
|
||||
backendConn, err = net.DialTimeout("tcp", outreq.URL.Host, rp.dialer.Timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -66,7 +67,7 @@ func TestSingleSRVHostReverseProxy(t *testing.T) {
|
||||
}
|
||||
port := uint16(pp)
|
||||
|
||||
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost)
|
||||
rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost, 30*time.Second, 300*time.Millisecond)
|
||||
rp.srvResolver = testResolver{
|
||||
result: []*net.SRV{
|
||||
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
|
||||
|
||||
@@ -49,6 +49,8 @@ type staticUpstream struct {
|
||||
Hosts HostPool
|
||||
Policy Policy
|
||||
KeepAlive int
|
||||
FallbackDelay time.Duration
|
||||
Timeout time.Duration
|
||||
FailTimeout time.Duration
|
||||
TryDuration time.Duration
|
||||
TryInterval time.Duration
|
||||
@@ -92,6 +94,7 @@ func NewStaticUpstreams(c caddyfile.Dispenser, host string) ([]Upstream, error)
|
||||
TryInterval: 250 * time.Millisecond,
|
||||
MaxConns: 0,
|
||||
KeepAlive: http.DefaultMaxIdleConnsPerHost,
|
||||
Timeout: 30 * time.Second,
|
||||
resolver: net.DefaultResolver,
|
||||
}
|
||||
|
||||
@@ -225,7 +228,7 @@ func (u *staticUpstream) NewHost(host string) (*UpstreamHost, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive)
|
||||
uh.ReverseProxy = NewSingleHostReverseProxy(baseURL, uh.WithoutPathPrefix, u.KeepAlive, u.Timeout, u.FallbackDelay)
|
||||
if u.insecureSkipVerify {
|
||||
uh.ReverseProxy.UseInsecureTransport()
|
||||
}
|
||||
@@ -307,6 +310,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
|
||||
arg = c.Val()
|
||||
}
|
||||
u.Policy = policyCreateFunc(arg)
|
||||
case "fallback_delay":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
dur, err := time.ParseDuration(c.Val())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
u.FallbackDelay = dur
|
||||
case "fail_timeout":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
@@ -431,9 +443,10 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
|
||||
}
|
||||
u.downstreamHeaders.Add(header, value)
|
||||
case "transparent":
|
||||
// Note: X-Forwarded-For header is always being appended for proxy connections
|
||||
// See implementation of createUpstreamRequest in proxy.go
|
||||
u.upstreamHeaders.Add("Host", "{host}")
|
||||
u.upstreamHeaders.Add("X-Real-IP", "{remote}")
|
||||
u.upstreamHeaders.Add("X-Forwarded-For", "{remote}")
|
||||
u.upstreamHeaders.Add("X-Forwarded-Proto", "{scheme}")
|
||||
case "websocket":
|
||||
u.upstreamHeaders.Add("Connection", "{>Connection}")
|
||||
@@ -463,6 +476,15 @@ func parseBlock(c *caddyfile.Dispenser, u *staticUpstream, hasSrv bool) error {
|
||||
return c.ArgErr()
|
||||
}
|
||||
u.KeepAlive = n
|
||||
case "timeout":
|
||||
if !c.NextArg() {
|
||||
return c.ArgErr()
|
||||
}
|
||||
dur, err := time.ParseDuration(c.Val())
|
||||
if err != nil {
|
||||
return c.Errf("unable to parse timeout duration '%s'", c.Val())
|
||||
}
|
||||
u.Timeout = dur
|
||||
default:
|
||||
return c.Errf("unknown property '%s'", c.Val())
|
||||
}
|
||||
@@ -608,6 +630,11 @@ func (u *staticUpstream) AllowedPath(requestPath string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// GetFallbackDelay returns u.FallbackDelay.
|
||||
func (u *staticUpstream) GetFallbackDelay() time.Duration {
|
||||
return u.FallbackDelay
|
||||
}
|
||||
|
||||
// GetTryDuration returns u.TryDuration.
|
||||
func (u *staticUpstream) GetTryDuration() time.Duration {
|
||||
return u.TryDuration
|
||||
@@ -618,6 +645,11 @@ func (u *staticUpstream) GetTryInterval() time.Duration {
|
||||
return u.TryInterval
|
||||
}
|
||||
|
||||
// GetTimeout returns u.Timeout.
|
||||
func (u *staticUpstream) GetTimeout() time.Duration {
|
||||
return u.Timeout
|
||||
}
|
||||
|
||||
func (u *staticUpstream) GetHostCount() int {
|
||||
return len(u.Hosts)
|
||||
}
|
||||
|
||||
@@ -282,7 +282,8 @@ func TestStop(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseBlock(t *testing.T) {
|
||||
func TestParseBlockTransparent(t *testing.T) {
|
||||
// tests for transparent proxy presets
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
tests := []struct {
|
||||
config string
|
||||
@@ -316,6 +317,10 @@ func TestParseBlock(t *testing.T) {
|
||||
if _, ok := headers["X-Forwarded-Proto"]; !ok {
|
||||
t.Errorf("Test %d: Could not find the X-Forwarded-Proto header", i+1)
|
||||
}
|
||||
|
||||
if _, ok := headers["X-Forwarded-For"]; ok {
|
||||
t.Errorf("Test %d: Found unexpected X-Forwarded-For header", i+1)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,22 +63,38 @@ type Rule interface {
|
||||
|
||||
// SimpleRule is a simple rewrite rule.
|
||||
type SimpleRule struct {
|
||||
From, To string
|
||||
Regexp *regexp.Regexp
|
||||
To string
|
||||
Negate bool
|
||||
}
|
||||
|
||||
// NewSimpleRule creates a new Simple Rule
|
||||
func NewSimpleRule(from, to string) SimpleRule {
|
||||
return SimpleRule{from, to}
|
||||
func NewSimpleRule(from, to string, negate bool) (*SimpleRule, error) {
|
||||
r, err := regexp.Compile(from)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SimpleRule{
|
||||
Regexp: r,
|
||||
To: to,
|
||||
Negate: negate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BasePath satisfies httpserver.Config
|
||||
func (s SimpleRule) BasePath() string { return s.From }
|
||||
func (s SimpleRule) BasePath() string { return "/" }
|
||||
|
||||
// Match satisfies httpserver.Config
|
||||
func (s SimpleRule) Match(r *http.Request) bool { return s.From == r.URL.Path }
|
||||
func (s *SimpleRule) Match(r *http.Request) bool {
|
||||
matches := regexpMatches(s.Regexp, "/", r.URL.Path)
|
||||
if s.Negate {
|
||||
return len(matches) == 0
|
||||
}
|
||||
return len(matches) > 0
|
||||
}
|
||||
|
||||
// Rewrite rewrites the internal location of the current request.
|
||||
func (s SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result {
|
||||
func (s *SimpleRule) Rewrite(fs http.FileSystem, r *http.Request) Result {
|
||||
|
||||
// attempt rewrite
|
||||
return To(fs, r, s.To, newReplacer(r))
|
||||
@@ -165,7 +181,7 @@ func (r ComplexRule) Match(req *http.Request) bool {
|
||||
return true
|
||||
}
|
||||
// otherwise validate regex
|
||||
return r.regexpMatches(req.URL.Path) != nil
|
||||
return regexpMatches(r.Regexp, r.Base, req.URL.Path) != nil
|
||||
}
|
||||
|
||||
// Rewrite rewrites the internal location of the current request.
|
||||
@@ -174,7 +190,7 @@ func (r ComplexRule) Rewrite(fs http.FileSystem, req *http.Request) (re Result)
|
||||
|
||||
// validate regexp if present
|
||||
if r.Regexp != nil {
|
||||
matches := r.regexpMatches(req.URL.Path)
|
||||
matches := regexpMatches(r.Regexp, r.Base, req.URL.Path)
|
||||
switch len(matches) {
|
||||
case 0:
|
||||
// no match
|
||||
@@ -230,14 +246,14 @@ func (r ComplexRule) matchExt(rPath string) bool {
|
||||
return !mustUse
|
||||
}
|
||||
|
||||
func (r ComplexRule) regexpMatches(rPath string) []string {
|
||||
if r.Regexp != nil {
|
||||
func regexpMatches(regexp *regexp.Regexp, base, rPath string) []string {
|
||||
if regexp != nil {
|
||||
// include trailing slash in regexp if present
|
||||
start := len(r.Base)
|
||||
if strings.HasSuffix(r.Base, "/") {
|
||||
start := len(base)
|
||||
if strings.HasSuffix(base, "/") {
|
||||
start--
|
||||
}
|
||||
return r.Regexp.FindStringSubmatch(rPath[start:])
|
||||
return regexp.FindStringSubmatch(rPath[start:])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -29,9 +29,9 @@ func TestRewrite(t *testing.T) {
|
||||
rw := Rewrite{
|
||||
Next: httpserver.HandlerFunc(urlPrinter),
|
||||
Rules: []httpserver.HandlerConfig{
|
||||
NewSimpleRule("/from", "/to"),
|
||||
NewSimpleRule("/a", "/b"),
|
||||
NewSimpleRule("/b", "/b{uri}"),
|
||||
newSimpleRule(t, "^/from$", "/to"),
|
||||
newSimpleRule(t, "^/a$", "/b"),
|
||||
newSimpleRule(t, "^/b$", "/b{uri}"),
|
||||
},
|
||||
FileSys: http.Dir("."),
|
||||
}
|
||||
@@ -131,6 +131,45 @@ func TestRewrite(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestWordpress is a test for wordpress usecase.
|
||||
func TestWordpress(t *testing.T) {
|
||||
rw := Rewrite{
|
||||
Next: httpserver.HandlerFunc(urlPrinter),
|
||||
Rules: []httpserver.HandlerConfig{
|
||||
// both rules are same, thanks to Go regexp (confusion).
|
||||
newSimpleRule(t, "^/wp-admin", "{path} {path}/ /index.php?{query}", true),
|
||||
newSimpleRule(t, "^\\/wp-admin", "{path} {path}/ /index.php?{query}", true),
|
||||
},
|
||||
FileSys: http.Dir("."),
|
||||
}
|
||||
tests := []struct {
|
||||
from string
|
||||
expectedTo string
|
||||
}{
|
||||
{"/wp-admin", "/wp-admin"},
|
||||
{"/wp-admin/login.php", "/wp-admin/login.php"},
|
||||
{"/not-wp-admin/login.php?not=admin", "/index.php?not=admin"},
|
||||
{"/loophole", "/index.php"},
|
||||
{"/user?name=john", "/index.php?name=john"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
req, err := http.NewRequest("GET", test.from, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Test %d: Could not create HTTP request: %v", i, err)
|
||||
}
|
||||
ctx := context.WithValue(req.Context(), httpserver.OriginalURLCtxKey, *req.URL)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
rw.ServeHTTP(rec, req)
|
||||
|
||||
if got, want := rec.Body.String(), test.expectedTo; got != want {
|
||||
t.Errorf("Test %d: Expected URL to be '%s' but was '%s'", i, want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func urlPrinter(w http.ResponseWriter, r *http.Request) (int, error) {
|
||||
fmt.Fprint(w, r.URL.String())
|
||||
return 0, nil
|
||||
|
||||
@@ -58,6 +58,7 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) {
|
||||
var base = "/"
|
||||
var pattern, to string
|
||||
var ext []string
|
||||
var negate bool
|
||||
|
||||
args := c.RemainingArgs()
|
||||
|
||||
@@ -111,7 +112,14 @@ func rewriteParse(c *caddy.Controller) ([]httpserver.HandlerConfig, error) {
|
||||
|
||||
// the only unhandled case is 2 and above
|
||||
default:
|
||||
rule = NewSimpleRule(args[0], strings.Join(args[1:], " "))
|
||||
if args[0] == "not" {
|
||||
negate = true
|
||||
args = args[1:]
|
||||
}
|
||||
rule, err = NewSimpleRule(args[0], strings.Join(args[1:], " "), negate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
|
||||
@@ -50,6 +50,19 @@ func TestSetup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// newSimpleRule is convenience test function for SimpleRule.
|
||||
func newSimpleRule(t *testing.T, from, to string, negate ...bool) Rule {
|
||||
var n bool
|
||||
if len(negate) > 0 {
|
||||
n = negate[0]
|
||||
}
|
||||
rule, err := NewSimpleRule(from, to, n)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
func TestRewriteParse(t *testing.T) {
|
||||
simpleTests := []struct {
|
||||
input string
|
||||
@@ -57,17 +70,20 @@ func TestRewriteParse(t *testing.T) {
|
||||
expected []Rule
|
||||
}{
|
||||
{`rewrite /from /to`, false, []Rule{
|
||||
SimpleRule{From: "/from", To: "/to"},
|
||||
newSimpleRule(t, "/from", "/to"),
|
||||
}},
|
||||
{`rewrite /from /to
|
||||
rewrite a b`, false, []Rule{
|
||||
SimpleRule{From: "/from", To: "/to"},
|
||||
SimpleRule{From: "a", To: "b"},
|
||||
newSimpleRule(t, "/from", "/to"),
|
||||
newSimpleRule(t, "a", "b"),
|
||||
}},
|
||||
{`rewrite a`, true, []Rule{}},
|
||||
{`rewrite`, true, []Rule{}},
|
||||
{`rewrite a b c`, false, []Rule{
|
||||
SimpleRule{From: "a", To: "b c"},
|
||||
newSimpleRule(t, "a", "b c"),
|
||||
}},
|
||||
{`rewrite not a b c`, false, []Rule{
|
||||
newSimpleRule(t, "a", "b c", true),
|
||||
}},
|
||||
}
|
||||
|
||||
@@ -88,17 +104,22 @@ func TestRewriteParse(t *testing.T) {
|
||||
}
|
||||
|
||||
for j, e := range test.expected {
|
||||
actualRule := actual[j].(SimpleRule)
|
||||
expectedRule := e.(SimpleRule)
|
||||
actualRule := actual[j].(*SimpleRule)
|
||||
expectedRule := e.(*SimpleRule)
|
||||
|
||||
if actualRule.From != expectedRule.From {
|
||||
if actualRule.Regexp.String() != expectedRule.Regexp.String() {
|
||||
t.Errorf("Test %d, rule %d: Expected From=%s, got %s",
|
||||
i, j, expectedRule.From, actualRule.From)
|
||||
i, j, expectedRule.Regexp.String(), actualRule.Regexp.String())
|
||||
}
|
||||
|
||||
if actualRule.To != expectedRule.To {
|
||||
t.Errorf("Test %d, rule %d: Expected To=%s, got %s",
|
||||
i, j, expectedRule.To, actualRule.To)
|
||||
i, j, expectedRule.Regexp.String(), actualRule.Regexp.String())
|
||||
}
|
||||
|
||||
if actualRule.Negate != expectedRule.Negate {
|
||||
t.Errorf("Test %d, rule %d: Expected Negate=%v, got %v",
|
||||
i, j, expectedRule.Negate, actualRule.Negate)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -110,6 +110,10 @@ func (t Templates) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error
|
||||
// set the actual content length now that the template was executed
|
||||
w.Header().Set("Content-Length", strconv.Itoa(buf.Len()))
|
||||
|
||||
// delete the headers related to cache
|
||||
w.Header().Del("ETag")
|
||||
w.Header().Del("Last-Modified")
|
||||
|
||||
// get the modification time in preparation for http.ServeContent
|
||||
modTime, _ := time.Parse(http.TimeFormat, w.Header().Get("Last-Modified"))
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ func TestTemplates(t *testing.T) {
|
||||
req string
|
||||
respCode int
|
||||
res string
|
||||
bypass bool
|
||||
}{
|
||||
{
|
||||
tpl: tmpl,
|
||||
@@ -113,6 +114,7 @@ func TestTemplates(t *testing.T) {
|
||||
respCode: http.StatusOK,
|
||||
res: `<!DOCTYPE html><html><head><title>as it is</title></head><body>{{.Include "header.html"}}</body></html>
|
||||
`,
|
||||
bypass: true,
|
||||
},
|
||||
} {
|
||||
c := c
|
||||
@@ -135,6 +137,14 @@ func TestTemplates(t *testing.T) {
|
||||
if respBody != c.res {
|
||||
t.Fatalf("Test: the expected body %v is different from the response one: %v", c.res, respBody)
|
||||
}
|
||||
|
||||
if !c.bypass {
|
||||
eTag := rec.Header().Get("ETag")
|
||||
lastModified := rec.Header().Get("Last-Modified")
|
||||
if eTag != "" || lastModified != "" {
|
||||
t.Fatalf("Test: expect a response without ETag or Last-Modified, got %v %v", eTag, lastModified)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestUnexportedGetCertificate(t *testing.T) {
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
|
||||
// When cache is empty
|
||||
if _, matched, defaulted := cfg.getCertificate("example.com"); matched || defaulted {
|
||||
t.Errorf("Got a certificate when cache was empty; matched=%v, defaulted=%v", matched, defaulted)
|
||||
}
|
||||
|
||||
// When cache has one certificate in it
|
||||
firstCert := Certificate{Names: []string{"example.com"}}
|
||||
certCache.cache["0xdeadbeef"] = firstCert
|
||||
cfg.Certificates["example.com"] = "0xdeadbeef"
|
||||
if cert, matched, defaulted := cfg.getCertificate("Example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
|
||||
t.Errorf("Didn't get a cert for 'Example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||
}
|
||||
if cert, matched, defaulted := cfg.getCertificate("example.com"); !matched || defaulted || cert.Names[0] != "example.com" {
|
||||
t.Errorf("Didn't get a cert for 'example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||
}
|
||||
|
||||
// When retrieving wildcard certificate
|
||||
certCache.cache["0xb01dface"] = Certificate{Names: []string{"*.example.com"}}
|
||||
cfg.Certificates["*.example.com"] = "0xb01dface"
|
||||
if cert, matched, defaulted := cfg.getCertificate("sub.example.com"); !matched || defaulted || cert.Names[0] != "*.example.com" {
|
||||
t.Errorf("Didn't get wildcard cert for 'sub.example.com' or got the wrong one: %v, matched=%v, defaulted=%v", cert, matched, defaulted)
|
||||
}
|
||||
|
||||
// When no certificate matches and SNI is provided, return no certificate (should be TLS alert)
|
||||
if cert, matched, defaulted := cfg.getCertificate("nomatch"); matched || defaulted {
|
||||
t.Errorf("Expected matched=false, defaulted=false; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
||||
}
|
||||
|
||||
// When no certificate matches and SNI is NOT provided, a random is returned
|
||||
if cert, matched, defaulted := cfg.getCertificate(""); matched || !defaulted {
|
||||
t.Errorf("Expected matched=false, defaulted=true; but got matched=%v, defaulted=%v (cert: %v)", matched, defaulted, cert)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheCertificate(t *testing.T) {
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
|
||||
cfg.cacheCertificate(Certificate{Names: []string{"example.com", "sub.example.com"}, Hash: "foobar"})
|
||||
if len(certCache.cache) != 1 {
|
||||
t.Errorf("Expected length of certificate cache to be 1")
|
||||
}
|
||||
if _, ok := certCache.cache["foobar"]; !ok {
|
||||
t.Error("Expected first cert to be cached by key 'foobar', but it wasn't")
|
||||
}
|
||||
if _, ok := cfg.Certificates["example.com"]; !ok {
|
||||
t.Error("Expected first cert to be keyed by 'example.com', but it wasn't")
|
||||
}
|
||||
if _, ok := cfg.Certificates["sub.example.com"]; !ok {
|
||||
t.Error("Expected first cert to be keyed by 'sub.example.com', but it wasn't")
|
||||
}
|
||||
|
||||
// different config, but using same cache; and has cert with overlapping name,
|
||||
// but different hash
|
||||
cfg2 := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
cfg2.cacheCertificate(Certificate{Names: []string{"example.com"}, Hash: "barbaz"})
|
||||
if _, ok := certCache.cache["barbaz"]; !ok {
|
||||
t.Error("Expected second cert to be cached by key 'barbaz.com', but it wasn't")
|
||||
}
|
||||
if hash, ok := cfg2.Certificates["example.com"]; !ok {
|
||||
t.Error("Expected second cert to be keyed by 'example.com', but it wasn't")
|
||||
} else if hash != "barbaz" {
|
||||
t.Errorf("Expected second cert to map to 'barbaz' but it was %s instead", hash)
|
||||
}
|
||||
}
|
||||
@@ -1,416 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
// acmeMu ensures that only one ACME challenge occurs at a time.
|
||||
var acmeMu sync.Mutex
|
||||
|
||||
// ACMEClient is a wrapper over acme.Client with
|
||||
// some custom state attached. It is used to obtain,
|
||||
// renew, and revoke certificates with ACME.
|
||||
type ACMEClient struct {
|
||||
AllowPrompts bool
|
||||
config *Config
|
||||
acmeClient *acme.Client
|
||||
storage Storage
|
||||
}
|
||||
|
||||
// newACMEClient creates a new ACMEClient given an email and whether
|
||||
// prompting the user is allowed. It's a variable so we can mock in tests.
|
||||
var newACMEClient = func(config *Config, allowPrompts bool) (*ACMEClient, error) {
|
||||
storage, err := config.StorageFor(config.CAUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Look up or create the LE user account
|
||||
leUser, err := getUser(storage, config.ACMEEmail)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// ensure key type is set
|
||||
keyType := DefaultKeyType
|
||||
if config.KeyType != "" {
|
||||
keyType = config.KeyType
|
||||
}
|
||||
|
||||
// ensure CA URL (directory endpoint) is set
|
||||
caURL := DefaultCAUrl
|
||||
if config.CAUrl != "" {
|
||||
caURL = config.CAUrl
|
||||
}
|
||||
|
||||
// ensure endpoint is secure (assume HTTPS if scheme is missing)
|
||||
if !strings.Contains(caURL, "://") {
|
||||
caURL = "https://" + caURL
|
||||
}
|
||||
u, err := url.Parse(caURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if u.Scheme != "https" && !caddy.IsLoopback(u.Host) && !caddy.IsInternal(u.Host) {
|
||||
return nil, fmt.Errorf("%s: insecure CA URL (HTTPS required)", caURL)
|
||||
}
|
||||
|
||||
// The client facilitates our communication with the CA server.
|
||||
client, err := acme.NewClient(caURL, &leUser, keyType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If not registered, the user must register an account with the CA
|
||||
// and agree to terms
|
||||
if leUser.Registration == nil {
|
||||
reg, err := client.Register()
|
||||
if err != nil {
|
||||
return nil, errors.New("registration error: " + err.Error())
|
||||
}
|
||||
leUser.Registration = reg
|
||||
|
||||
if allowPrompts { // can't prompt a user who isn't there
|
||||
if !Agreed && reg.TosURL == "" {
|
||||
Agreed = promptUserAgreement(saURL, false) // TODO - latest URL
|
||||
}
|
||||
if !Agreed && reg.TosURL == "" {
|
||||
return nil, errors.New("user must agree to terms")
|
||||
}
|
||||
}
|
||||
|
||||
err = client.AgreeToTOS()
|
||||
if err != nil {
|
||||
saveUser(storage, leUser) // Might as well try, right?
|
||||
return nil, errors.New("error agreeing to terms: " + err.Error())
|
||||
}
|
||||
|
||||
// save user to the file system
|
||||
err = saveUser(storage, leUser)
|
||||
if err != nil {
|
||||
return nil, errors.New("could not save user: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
c := &ACMEClient{
|
||||
AllowPrompts: allowPrompts,
|
||||
config: config,
|
||||
acmeClient: client,
|
||||
storage: storage,
|
||||
}
|
||||
|
||||
if config.DNSProvider == "" {
|
||||
// Use HTTP and TLS-SNI challenges by default
|
||||
|
||||
// See if HTTP challenge needs to be proxied
|
||||
useHTTPPort := HTTPChallengePort
|
||||
if config.AltHTTPPort != "" {
|
||||
useHTTPPort = config.AltHTTPPort
|
||||
}
|
||||
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useHTTPPort)) {
|
||||
useHTTPPort = DefaultHTTPAlternatePort
|
||||
}
|
||||
|
||||
// See which port TLS-SNI challenges will be accomplished on
|
||||
useTLSSNIPort := TLSSNIChallengePort
|
||||
if config.AltTLSSNIPort != "" {
|
||||
useTLSSNIPort = config.AltTLSSNIPort
|
||||
}
|
||||
|
||||
// Always respect user's bind preferences by using config.ListenHost.
|
||||
// NOTE(Sep'16): At time of writing, SetHTTPAddress() and SetTLSAddress()
|
||||
// must be called before SetChallengeProvider(), since they reset the
|
||||
// challenge provider back to the default one!
|
||||
err := c.acmeClient.SetHTTPAddress(net.JoinHostPort(config.ListenHost, useHTTPPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = c.acmeClient.SetTLSAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// See if TLS challenge needs to be handled by our own facilities
|
||||
if caddy.HasListenerWithAddress(net.JoinHostPort(config.ListenHost, useTLSSNIPort)) {
|
||||
c.acmeClient.SetChallengeProvider(acme.TLSSNI01, tlsSNISolver{certCache: config.certCache})
|
||||
}
|
||||
|
||||
// Disable any challenges that should not be used
|
||||
var disabledChallenges []acme.Challenge
|
||||
if DisableHTTPChallenge {
|
||||
disabledChallenges = append(disabledChallenges, acme.HTTP01)
|
||||
}
|
||||
if DisableTLSSNIChallenge {
|
||||
disabledChallenges = append(disabledChallenges, acme.TLSSNI01)
|
||||
}
|
||||
if len(disabledChallenges) > 0 {
|
||||
c.acmeClient.ExcludeChallenges(disabledChallenges)
|
||||
}
|
||||
} else {
|
||||
// Otherwise, use DNS challenge exclusively
|
||||
|
||||
// Load provider constructor function
|
||||
provFn, ok := dnsProviders[config.DNSProvider]
|
||||
if !ok {
|
||||
return nil, errors.New("unknown DNS provider by name '" + config.DNSProvider + "'")
|
||||
}
|
||||
|
||||
// We could pass credentials to create the provider, but for now
|
||||
// just let the solver package get them from the environment
|
||||
prov, err := provFn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Use the DNS challenge exclusively
|
||||
c.acmeClient.ExcludeChallenges([]acme.Challenge{acme.HTTP01, acme.TLSSNI01})
|
||||
c.acmeClient.SetChallengeProvider(acme.DNS01, prov)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// Obtain obtains a single certificate for name. It stores the certificate
|
||||
// on the disk if successful. This function is safe for concurrent use.
|
||||
//
|
||||
// Right now our storage mechanism only supports one name per certificate,
|
||||
// so this function (along with Renew and Revoke) only accepts one domain
|
||||
// as input. It can be easily modified to support SAN certificates if our
|
||||
// storage mechanism is upgraded later.
|
||||
//
|
||||
// Callers who have access to a Config value should use the ObtainCert
|
||||
// method on that instead of this lower-level method.
|
||||
func (c *ACMEClient) Obtain(name string) error {
|
||||
waiter, err := c.storage.TryLock(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if waiter != nil {
|
||||
log.Printf("[INFO] Certificate for %s is already being obtained elsewhere and stored; waiting", name)
|
||||
waiter.Wait()
|
||||
return nil // we assume the process with the lock succeeded, rather than hammering this execution path again
|
||||
}
|
||||
defer func() {
|
||||
if err := c.storage.Unlock(name); err != nil {
|
||||
log.Printf("[ERROR] Unable to unlock obtain call for %s: %v", name, err)
|
||||
}
|
||||
}()
|
||||
|
||||
Attempts:
|
||||
for attempts := 0; attempts < 2; attempts++ {
|
||||
namesObtaining.Add([]string{name})
|
||||
acmeMu.Lock()
|
||||
certificate, failures := c.acmeClient.ObtainCertificate([]string{name}, true, nil, c.config.MustStaple)
|
||||
acmeMu.Unlock()
|
||||
namesObtaining.Remove([]string{name})
|
||||
if len(failures) > 0 {
|
||||
// Error - try to fix it or report it to the user and abort
|
||||
var errMsg string // we'll combine all the failures into a single error message
|
||||
var promptedForAgreement bool // only prompt user for agreement at most once
|
||||
|
||||
for errDomain, obtainErr := range failures {
|
||||
if obtainErr == nil {
|
||||
continue
|
||||
}
|
||||
if tosErr, ok := obtainErr.(acme.TOSError); ok {
|
||||
// Terms of Service agreement error; we can probably deal with this
|
||||
if !Agreed && !promptedForAgreement && c.AllowPrompts {
|
||||
Agreed = promptUserAgreement(tosErr.Detail, true) // TODO: Use latest URL
|
||||
promptedForAgreement = true
|
||||
}
|
||||
if Agreed || !c.AllowPrompts {
|
||||
err := c.acmeClient.AgreeToTOS()
|
||||
if err != nil {
|
||||
return errors.New("error agreeing to updated terms: " + err.Error())
|
||||
}
|
||||
continue Attempts
|
||||
}
|
||||
}
|
||||
|
||||
// If user did not agree or it was any other kind of error, just append to the list of errors
|
||||
errMsg += "[" + errDomain + "] failed to get certificate: " + obtainErr.Error() + "\n"
|
||||
}
|
||||
return errors.New(errMsg)
|
||||
}
|
||||
|
||||
// Success - immediately save the certificate resource
|
||||
err = saveCertResource(c.storage, certificate)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error saving assets for %v: %v", name, err)
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Renew renews the managed certificate for name. It puts the renewed
|
||||
// certificate into storage (not the cache). This function is safe for
|
||||
// concurrent use.
|
||||
//
|
||||
// Callers who have access to a Config value should use the RenewCert
|
||||
// method on that instead of this lower-level method.
|
||||
func (c *ACMEClient) Renew(name string) error {
|
||||
waiter, err := c.storage.TryLock(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if waiter != nil {
|
||||
log.Printf("[INFO] Certificate for %s is already being renewed elsewhere and stored; waiting", name)
|
||||
waiter.Wait()
|
||||
return nil // assume that the worker that renewed the cert succeeded; avoid hammering this path over and over
|
||||
}
|
||||
defer func() {
|
||||
if err := c.storage.Unlock(name); err != nil {
|
||||
log.Printf("[ERROR] Unable to unlock renew call for %s: %v", name, err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Prepare for renewal (load PEM cert, key, and meta)
|
||||
siteData, err := c.storage.LoadSite(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var certMeta acme.CertificateResource
|
||||
err = json.Unmarshal(siteData.Meta, &certMeta)
|
||||
certMeta.Certificate = siteData.Cert
|
||||
certMeta.PrivateKey = siteData.Key
|
||||
|
||||
// Perform renewal and retry if necessary, but not too many times.
|
||||
var newCertMeta acme.CertificateResource
|
||||
var success bool
|
||||
for attempts := 0; attempts < 2; attempts++ {
|
||||
namesObtaining.Add([]string{name})
|
||||
acmeMu.Lock()
|
||||
newCertMeta, err = c.acmeClient.RenewCertificate(certMeta, true, c.config.MustStaple)
|
||||
acmeMu.Unlock()
|
||||
namesObtaining.Remove([]string{name})
|
||||
if err == nil {
|
||||
success = true
|
||||
break
|
||||
}
|
||||
|
||||
// If the legal terms were updated and need to be
|
||||
// agreed to again, we can handle that.
|
||||
if _, ok := err.(acme.TOSError); ok {
|
||||
err := c.acmeClient.AgreeToTOS()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// For any other kind of error, wait 10s and try again.
|
||||
wait := 10 * time.Second
|
||||
log.Printf("[ERROR] Renewing: %v; trying again in %s", err, wait)
|
||||
time.Sleep(wait)
|
||||
}
|
||||
|
||||
if !success {
|
||||
return errors.New("too many renewal attempts; last error: " + err.Error())
|
||||
}
|
||||
|
||||
caddy.EmitEvent(caddy.CertRenewEvent, name)
|
||||
|
||||
return saveCertResource(c.storage, newCertMeta)
|
||||
}
|
||||
|
||||
// Revoke revokes the certificate for name and deletes
|
||||
// it from storage.
|
||||
func (c *ACMEClient) Revoke(name string) error {
|
||||
siteExists, err := c.storage.SiteExists(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !siteExists {
|
||||
return errors.New("no certificate and key for " + name)
|
||||
}
|
||||
|
||||
siteData, err := c.storage.LoadSite(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.acmeClient.RevokeCertificate(siteData.Cert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.storage.DeleteSite(name)
|
||||
if err != nil {
|
||||
return errors.New("certificate revoked, but unable to delete certificate file: " + err.Error())
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// namesObtaining is a set of hostnames with thread-safe
|
||||
// methods. A name should be in this set only while this
|
||||
// package is in the process of obtaining a certificate
|
||||
// for the name. ACME challenges that are received for
|
||||
// names which are not in this set were not initiated by
|
||||
// this package and probably should not be handled by
|
||||
// this package.
|
||||
var namesObtaining = nameCoordinator{names: make(map[string]struct{})}
|
||||
|
||||
type nameCoordinator struct {
|
||||
names map[string]struct{}
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// Add adds names to c. It is safe for concurrent use.
|
||||
func (c *nameCoordinator) Add(names []string) {
|
||||
c.mu.Lock()
|
||||
for _, name := range names {
|
||||
c.names[strings.ToLower(name)] = struct{}{}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Remove removes names from c. It is safe for concurrent use.
|
||||
func (c *nameCoordinator) Remove(names []string) {
|
||||
c.mu.Lock()
|
||||
for _, name := range names {
|
||||
delete(c.names, strings.ToLower(name))
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
// Has returns true if c has name. It is safe for concurrent use.
|
||||
func (c *nameCoordinator) Has(name string) bool {
|
||||
hostname, _, err := net.SplitHostPort(name)
|
||||
if err != nil {
|
||||
hostname = name
|
||||
}
|
||||
c.mu.RLock()
|
||||
_, ok := c.names[strings.ToLower(hostname)]
|
||||
c.mu.RUnlock()
|
||||
return ok
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
// TODO
|
||||
+86
-220
@@ -20,12 +20,12 @@ import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
||||
"net/url"
|
||||
"strings"
|
||||
"github.com/xenolf/lego/challenge/tlsalpn01"
|
||||
|
||||
"github.com/codahale/aesnicheck"
|
||||
"github.com/klauspost/cpuid"
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/mholt/certmagic"
|
||||
"github.com/xenolf/lego/certcrypto"
|
||||
)
|
||||
|
||||
// Config describes how TLS should be configured and used.
|
||||
@@ -63,102 +63,31 @@ type Config struct {
|
||||
// Manual means user provides own certs and keys
|
||||
Manual bool
|
||||
|
||||
// Managed means config qualifies for implicit,
|
||||
// automatic, managed TLS; as opposed to the user
|
||||
// providing and managing the certificate manually
|
||||
// Managed means this config should be managed
|
||||
// by the CertMagic Config (Manager field)
|
||||
Managed bool
|
||||
|
||||
// OnDemand means the class of hostnames this
|
||||
// config applies to may obtain and manage
|
||||
// certificates at handshake-time (as opposed
|
||||
// to pre-loaded at startup); OnDemand certs
|
||||
// will be managed the same way as preloaded
|
||||
// ones, however, if an OnDemand cert fails to
|
||||
// renew, it is removed from the in-memory
|
||||
// cache; if this is true, Managed must
|
||||
// necessarily be true
|
||||
OnDemand bool
|
||||
// Manager is how certificates are managed
|
||||
Manager *certmagic.Config
|
||||
|
||||
// SelfSigned means that this hostname is
|
||||
// served with a self-signed certificate
|
||||
// that we generated in memory for convenience
|
||||
SelfSigned bool
|
||||
|
||||
// The endpoint of the directory for the ACME
|
||||
// CA we are to use
|
||||
CAUrl string
|
||||
|
||||
// The host (ONLY the host, not port) to listen
|
||||
// on if necessary to start a listener to solve
|
||||
// an ACME challenge
|
||||
ListenHost string
|
||||
|
||||
// The alternate port (ONLY port, not host) to
|
||||
// use for the ACME HTTP challenge; if non-empty,
|
||||
// this port will be used instead of
|
||||
// HTTPChallengePort to spin up a listener for
|
||||
// the HTTP challenge
|
||||
AltHTTPPort string
|
||||
|
||||
// The alternate port (ONLY port, not host)
|
||||
// to use for the ACME TLS-SNI challenge.
|
||||
// The system must forward TLSSNIChallengePort
|
||||
// to this port for challenge to succeed
|
||||
AltTLSSNIPort string
|
||||
|
||||
// The string identifier of the DNS provider
|
||||
// to use when solving the ACME DNS challenge
|
||||
DNSProvider string
|
||||
|
||||
// The email address to use when creating or
|
||||
// using an ACME account (fun fact: if this
|
||||
// is set to "off" then this config will not
|
||||
// qualify for managed TLS)
|
||||
ACMEEmail string
|
||||
|
||||
// The type of key to use when generating
|
||||
// certificates
|
||||
KeyType acme.KeyType
|
||||
|
||||
// The storage creator; use StorageFor() to get a guaranteed
|
||||
// non-nil Storage instance. Note, Caddy may call this frequently
|
||||
// so implementors are encouraged to cache any heavy instantiations.
|
||||
StorageProvider string
|
||||
|
||||
// The state needed to operate on-demand TLS
|
||||
OnDemandState OnDemandState
|
||||
|
||||
// Add the must staple TLS extension to the CSR generated by lego/acme
|
||||
MustStaple bool
|
||||
|
||||
// The list of protocols to choose from for Application Layer
|
||||
// Protocol Negotiation (ALPN).
|
||||
ALPN []string
|
||||
|
||||
// The map of hostname to certificate hash. This is used to complete
|
||||
// handshakes and serve the right certificate given the SNI.
|
||||
Certificates map[string]string
|
||||
|
||||
certCache *certificateCache // pointer to the Instance's certificate store
|
||||
tlsConfig *tls.Config // the final tls.Config created with buildStandardTLSConfig()
|
||||
}
|
||||
|
||||
// OnDemandState contains some state relevant for providing
|
||||
// on-demand TLS.
|
||||
type OnDemandState struct {
|
||||
// The number of certificates that have been issued on-demand
|
||||
// by this config. It is only safe to modify this count atomically.
|
||||
// If it reaches MaxObtain, on-demand issuances must fail.
|
||||
ObtainedCount int32
|
||||
|
||||
// Set from max_certs in tls config, it specifies the
|
||||
// maximum number of certificates that can be issued.
|
||||
MaxObtain int32
|
||||
|
||||
// The url to call to check if an on-demand tls certificate should
|
||||
// be issued. If a request to the URL fails or returns a non 2xx
|
||||
// status on-demand issuances must fail.
|
||||
AskURL *url.URL
|
||||
// The final tls.Config created with
|
||||
// buildStandardTLSConfig()
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
// NewConfig returns a new Config with a pointer to the instance's
|
||||
@@ -166,112 +95,21 @@ type OnDemandState struct {
|
||||
// the returned Config for successful practical use.
|
||||
func NewConfig(inst *caddy.Instance) *Config {
|
||||
inst.StorageMu.RLock()
|
||||
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
|
||||
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certmagic.Cache)
|
||||
inst.StorageMu.RUnlock()
|
||||
if !ok || certCache == nil {
|
||||
certCache = &certificateCache{cache: make(map[string]Certificate)}
|
||||
certCache = certmagic.NewCache(certmagic.DefaultStorage)
|
||||
inst.OnShutdown = append(inst.OnShutdown, func() error {
|
||||
certCache.Stop()
|
||||
return nil
|
||||
})
|
||||
inst.StorageMu.Lock()
|
||||
inst.Storage[CertCacheInstStorageKey] = certCache
|
||||
inst.StorageMu.Unlock()
|
||||
}
|
||||
cfg := new(Config)
|
||||
cfg.Certificates = make(map[string]string)
|
||||
cfg.certCache = certCache
|
||||
return cfg
|
||||
}
|
||||
|
||||
// ObtainCert obtains a certificate for name using c, as long
|
||||
// as a certificate does not already exist in storage for that
|
||||
// name. The name must qualify and c must be flagged as Managed.
|
||||
// This function is a no-op if storage already has a certificate
|
||||
// for name.
|
||||
//
|
||||
// It only obtains and stores certificates (and their keys),
|
||||
// it does not load them into memory. If allowPrompts is true,
|
||||
// the user may be shown a prompt.
|
||||
func (c *Config) ObtainCert(name string, allowPrompts bool) error {
|
||||
if !c.Managed || !HostQualifies(name) {
|
||||
return nil
|
||||
return &Config{
|
||||
Manager: certmagic.NewWithCache(certCache, certmagic.Config{}),
|
||||
}
|
||||
|
||||
storage, err := c.StorageFor(c.CAUrl)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
siteExists, err := storage.SiteExists(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if siteExists {
|
||||
return nil
|
||||
}
|
||||
if c.ACMEEmail == "" {
|
||||
c.ACMEEmail = getEmail(storage, allowPrompts)
|
||||
}
|
||||
|
||||
client, err := newACMEClient(c, allowPrompts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.Obtain(name)
|
||||
}
|
||||
|
||||
// RenewCert renews the certificate for name using c. It stows the
|
||||
// renewed certificate and its assets in storage if successful.
|
||||
func (c *Config) RenewCert(name string, allowPrompts bool) error {
|
||||
client, err := newACMEClient(c, allowPrompts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.Renew(name)
|
||||
}
|
||||
|
||||
// StorageFor obtains a TLS Storage instance for the given CA URL which should
|
||||
// be unique for every different ACME CA. If a StorageCreator is set on this
|
||||
// Config, it will be used. Otherwise the default file storage implementation
|
||||
// is used. When the error is nil, this is guaranteed to return a non-nil
|
||||
// Storage instance.
|
||||
func (c *Config) StorageFor(caURL string) (Storage, error) {
|
||||
// Validate CA URL
|
||||
if caURL == "" {
|
||||
caURL = DefaultCAUrl
|
||||
}
|
||||
if caURL == "" {
|
||||
return nil, fmt.Errorf("cannot create storage without CA URL")
|
||||
}
|
||||
caURL = strings.ToLower(caURL)
|
||||
|
||||
// scheme required or host will be parsed as path (as of Go 1.6)
|
||||
if !strings.Contains(caURL, "://") {
|
||||
caURL = "https://" + caURL
|
||||
}
|
||||
|
||||
u, err := url.Parse(caURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: unable to parse CA URL: %v", caURL, err)
|
||||
}
|
||||
|
||||
if u.Host == "" {
|
||||
return nil, fmt.Errorf("%s: no host in CA URL", caURL)
|
||||
}
|
||||
|
||||
// Create the storage based on the URL
|
||||
var s Storage
|
||||
if c.StorageProvider == "" {
|
||||
c.StorageProvider = "file"
|
||||
}
|
||||
|
||||
creator, ok := storageProviders[c.StorageProvider]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%s: Unknown storage: %v", caURL, c.StorageProvider)
|
||||
}
|
||||
|
||||
s, err = creator(u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s: unable to create custom storage '%v': %v", caURL, c.StorageProvider, err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// buildStandardTLSConfig converts cfg (*caddytls.Config) to a *tls.Config
|
||||
@@ -305,11 +143,23 @@ func (c *Config) buildStandardTLSConfig() error {
|
||||
}
|
||||
}
|
||||
|
||||
// ensure ALPN includes the ACME TLS-ALPN protocol
|
||||
var alpnFound bool
|
||||
for _, a := range c.ALPN {
|
||||
if a == tlsalpn01.ACMETLS1Protocol {
|
||||
alpnFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !alpnFound {
|
||||
c.ALPN = append(c.ALPN, tlsalpn01.ACMETLS1Protocol)
|
||||
}
|
||||
|
||||
config.MinVersion = c.ProtocolMinVersion
|
||||
config.MaxVersion = c.ProtocolMaxVersion
|
||||
config.ClientAuth = c.ClientAuth
|
||||
config.NextProtos = c.ALPN
|
||||
config.GetCertificate = c.GetCertificate
|
||||
config.GetCertificate = c.Manager.GetCertificate
|
||||
|
||||
// set up client authentication if enabled
|
||||
if config.ClientAuth != tls.NoClientCert {
|
||||
@@ -395,7 +245,7 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
|
||||
// compatible), otherwise that is a configuration error
|
||||
if otherConfig, ok := configMap[cfg.Hostname]; ok {
|
||||
if err := assertConfigsCompatible(cfg, otherConfig); err != nil {
|
||||
return nil, fmt.Errorf("incompabile TLS configurations for the same SNI "+
|
||||
return nil, fmt.Errorf("incompatible TLS configurations for the same SNI "+
|
||||
"name (%s) on the same listener: %v",
|
||||
cfg.Hostname, err)
|
||||
}
|
||||
@@ -419,6 +269,13 @@ func MakeTLSConfig(configs []*Config) (*tls.Config, error) {
|
||||
}
|
||||
|
||||
return &tls.Config{
|
||||
// A tls.Config must have Certificates or GetCertificate
|
||||
// set, in order to be accepted by tls.Listen and quic.Listen.
|
||||
// TODO: remove this once the standard library allows a tls.Config with
|
||||
// only GetConfigForClient set.
|
||||
GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return nil, fmt.Errorf("all certificates configured via GetConfigForClient")
|
||||
},
|
||||
GetConfigForClient: configMap.GetConfigForClient,
|
||||
}, nil
|
||||
}
|
||||
@@ -476,6 +333,14 @@ func assertConfigsCompatible(cfg1, cfg2 *Config) error {
|
||||
if c1.ClientAuth != c2.ClientAuth {
|
||||
return fmt.Errorf("client authentication policy mismatch")
|
||||
}
|
||||
if c1.ClientAuth != tls.NoClientCert && c2.ClientAuth != tls.NoClientCert && c1.ClientCAs != c2.ClientCAs {
|
||||
// Two hosts defined on the same listener are not compatible if they
|
||||
// have ClientAuth enabled, because there's no guarantee beyond the
|
||||
// hostname which config will be used (because SNI only has server name).
|
||||
// To prevent clients from bypassing authentication, require that
|
||||
// ClientAuth be configured in an unambiguous manner.
|
||||
return fmt.Errorf("multiple hosts requiring client authentication ambiguously configured")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -511,7 +376,7 @@ func SetDefaultTLSParams(config *Config) {
|
||||
|
||||
// Set default protocol min and max versions - must balance compatibility and security
|
||||
if config.ProtocolMinVersion == 0 {
|
||||
config.ProtocolMinVersion = tls.VersionTLS11
|
||||
config.ProtocolMinVersion = tls.VersionTLS12
|
||||
}
|
||||
if config.ProtocolMaxVersion == 0 {
|
||||
config.ProtocolMaxVersion = tls.VersionTLS12
|
||||
@@ -522,23 +387,35 @@ func SetDefaultTLSParams(config *Config) {
|
||||
}
|
||||
|
||||
// Map of supported key types
|
||||
var supportedKeyTypes = map[string]acme.KeyType{
|
||||
"P384": acme.EC384,
|
||||
"P256": acme.EC256,
|
||||
"RSA8192": acme.RSA8192,
|
||||
"RSA4096": acme.RSA4096,
|
||||
"RSA2048": acme.RSA2048,
|
||||
var supportedKeyTypes = map[string]certcrypto.KeyType{
|
||||
"P384": certcrypto.EC384,
|
||||
"P256": certcrypto.EC256,
|
||||
"RSA8192": certcrypto.RSA8192,
|
||||
"RSA4096": certcrypto.RSA4096,
|
||||
"RSA2048": certcrypto.RSA2048,
|
||||
}
|
||||
|
||||
// Map of supported protocols.
|
||||
// SupportedProtocols is a map of supported protocols.
|
||||
// HTTP/2 only supports TLS 1.2 and higher.
|
||||
var supportedProtocols = map[string]uint16{
|
||||
// If updating this map, also update tlsProtocolStringToMap in caddyhttp/fastcgi/fastcgi.go
|
||||
var SupportedProtocols = map[string]uint16{
|
||||
"tls1.0": tls.VersionTLS10,
|
||||
"tls1.1": tls.VersionTLS11,
|
||||
"tls1.2": tls.VersionTLS12,
|
||||
}
|
||||
|
||||
// Map of supported ciphers, used only for parsing config.
|
||||
// GetSupportedProtocolName returns the protocol name
|
||||
func GetSupportedProtocolName(protocol uint16) (string, error) {
|
||||
for k, v := range SupportedProtocols {
|
||||
if v == protocol {
|
||||
return k, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("name: unsuported protocol")
|
||||
}
|
||||
|
||||
// SupportedCiphersMap has supported ciphers, used only for parsing config.
|
||||
//
|
||||
// Note that, at time of writing, HTTP/2 blacklists 276 cipher suites,
|
||||
// including all but four of the suites below (the four GCM suites).
|
||||
@@ -548,7 +425,7 @@ var supportedProtocols = map[string]uint16{
|
||||
// it is always added (even though it is not technically a cipher suite).
|
||||
//
|
||||
// This map, like any map, is NOT ORDERED. Do not range over this map.
|
||||
var supportedCiphersMap = map[string]uint16{
|
||||
var SupportedCiphersMap = map[string]uint16{
|
||||
"ECDHE-ECDSA-AES256-GCM-SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
"ECDHE-RSA-AES256-GCM-SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
"ECDHE-ECDSA-AES128-GCM-SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
@@ -565,6 +442,17 @@ var supportedCiphersMap = map[string]uint16{
|
||||
"RSA-3DES-EDE-CBC-SHA": tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
|
||||
}
|
||||
|
||||
// GetSupportedCipherName returns the cipher name
|
||||
func GetSupportedCipherName(cipher uint16) (string, error) {
|
||||
for k, v := range SupportedCiphersMap {
|
||||
if v == cipher {
|
||||
return k, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("name: unsuported cipher")
|
||||
}
|
||||
|
||||
// List of all the ciphers we want to use by default
|
||||
var defaultCiphers = []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
@@ -577,8 +465,6 @@ var defaultCiphers = []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
}
|
||||
|
||||
// List of ciphers we should prefer if native AESNI support is missing
|
||||
@@ -593,8 +479,6 @@ var defaultCiphersNonAESNI = []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_256_CBC_SHA,
|
||||
tls.TLS_RSA_WITH_AES_128_CBC_SHA,
|
||||
}
|
||||
|
||||
// getPreferredDefaultCiphers returns an appropriate cipher suite to use, depending on
|
||||
@@ -602,7 +486,7 @@ var defaultCiphersNonAESNI = []uint16{
|
||||
//
|
||||
// See https://github.com/mholt/caddy/issues/1674
|
||||
func getPreferredDefaultCiphers() []uint16 {
|
||||
if aesnicheck.HasAESNI() {
|
||||
if cpuid.CPU.AesNi() {
|
||||
return defaultCiphers
|
||||
}
|
||||
|
||||
@@ -629,24 +513,6 @@ var defaultCurves = []tls.CurveID{
|
||||
tls.CurveP256,
|
||||
}
|
||||
|
||||
const (
|
||||
// HTTPChallengePort is the officially designated port for
|
||||
// the HTTP challenge according to the ACME spec.
|
||||
HTTPChallengePort = "80"
|
||||
|
||||
// TLSSNIChallengePort is the officially designated port for
|
||||
// the TLS-SNI challenge according to the ACME spec.
|
||||
TLSSNIChallengePort = "443"
|
||||
|
||||
// DefaultHTTPAlternatePort is the port on which the ACME
|
||||
// client will open a listener and solve the HTTP challenge.
|
||||
// If this alternate port is used instead of the default
|
||||
// port, then whatever is listening on the default port must
|
||||
// be capable of proxying or forwarding the request to this
|
||||
// alternate port.
|
||||
DefaultHTTPAlternatePort = "5033"
|
||||
|
||||
// CertCacheInstStorageKey is the name of the key for
|
||||
// accessing the certificate storage on the *caddy.Instance.
|
||||
CertCacheInstStorageKey = "tls_cert_cache"
|
||||
)
|
||||
// CertCacheInstStorageKey is the name of the key for
|
||||
// accessing the certificate storage on the *caddy.Instance.
|
||||
const CertCacheInstStorageKey = "tls_cert_cache"
|
||||
|
||||
+2
-121
@@ -16,12 +16,10 @@ package caddytls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/codahale/aesnicheck"
|
||||
"github.com/klauspost/cpuid"
|
||||
)
|
||||
|
||||
func TestConvertTLSConfigProtocolVersions(t *testing.T) {
|
||||
@@ -98,7 +96,7 @@ func TestConvertTLSConfigCipherSuites(t *testing.T) {
|
||||
|
||||
func TestGetPreferredDefaultCiphers(t *testing.T) {
|
||||
expectedCiphers := defaultCiphers
|
||||
if !aesnicheck.HasAESNI() {
|
||||
if !cpuid.CPU.AesNi() {
|
||||
expectedCiphers = defaultCiphersNonAESNI
|
||||
}
|
||||
|
||||
@@ -110,120 +108,3 @@ func TestGetPreferredDefaultCiphers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorageForNoURL(t *testing.T) {
|
||||
c := &Config{}
|
||||
if _, err := c.StorageFor(""); err == nil {
|
||||
t.Fatal("Expected error on empty URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorageForLowercasesAndPrefixesScheme(t *testing.T) {
|
||||
resultStr := ""
|
||||
RegisterStorageProvider("fake-TestStorageForLowercasesAndPrefixesScheme", func(caURL *url.URL) (Storage, error) {
|
||||
resultStr = caURL.String()
|
||||
return nil, nil
|
||||
})
|
||||
c := &Config{
|
||||
StorageProvider: "fake-TestStorageForLowercasesAndPrefixesScheme",
|
||||
}
|
||||
if _, err := c.StorageFor("EXAMPLE.COM/BLAH"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resultStr != "https://example.com/blah" {
|
||||
t.Fatalf("Unexpected CA URL string: %v", resultStr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorageForBadURL(t *testing.T) {
|
||||
c := &Config{}
|
||||
if _, err := c.StorageFor("http://192.168.0.%31/"); err == nil {
|
||||
t.Fatal("Expected error for bad URL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorageForDefault(t *testing.T) {
|
||||
c := &Config{}
|
||||
s, err := c.StorageFor("example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := s.(*FileStorage); !ok {
|
||||
t.Fatalf("Unexpected storage type: %#v", s)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorageForCustom(t *testing.T) {
|
||||
storage := fakeStorage("fake-TestStorageForCustom")
|
||||
RegisterStorageProvider("fake-TestStorageForCustom", func(caURL *url.URL) (Storage, error) { return storage, nil })
|
||||
c := &Config{
|
||||
StorageProvider: "fake-TestStorageForCustom",
|
||||
}
|
||||
s, err := c.StorageFor("example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if s != storage {
|
||||
t.Fatal("Unexpected storage")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorageForCustomError(t *testing.T) {
|
||||
RegisterStorageProvider("fake-TestStorageForCustomError", func(caURL *url.URL) (Storage, error) { return nil, errors.New("some error") })
|
||||
c := &Config{
|
||||
StorageProvider: "fake-TestStorageForCustomError",
|
||||
}
|
||||
if _, err := c.StorageFor("example.com"); err == nil {
|
||||
t.Fatal("Expecting error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorageForCustomNil(t *testing.T) {
|
||||
// Should fall through to the default
|
||||
c := &Config{StorageProvider: ""}
|
||||
s, err := c.StorageFor("example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, ok := s.(*FileStorage); !ok {
|
||||
t.Fatalf("Unexpected storage type: %#v", s)
|
||||
}
|
||||
}
|
||||
|
||||
type fakeStorage string
|
||||
|
||||
func (s fakeStorage) SiteExists(domain string) (bool, error) {
|
||||
panic("no impl")
|
||||
}
|
||||
|
||||
func (s fakeStorage) LoadSite(domain string) (*SiteData, error) {
|
||||
panic("no impl")
|
||||
}
|
||||
|
||||
func (s fakeStorage) StoreSite(domain string, data *SiteData) error {
|
||||
panic("no impl")
|
||||
}
|
||||
|
||||
func (s fakeStorage) DeleteSite(domain string) error {
|
||||
panic("no impl")
|
||||
}
|
||||
|
||||
func (s fakeStorage) TryLock(domain string) (Waiter, error) {
|
||||
panic("no impl")
|
||||
}
|
||||
|
||||
func (s fakeStorage) Unlock(domain string) error {
|
||||
panic("no impl")
|
||||
}
|
||||
|
||||
func (s fakeStorage) LoadUser(email string) (*UserData, error) {
|
||||
panic("no impl")
|
||||
}
|
||||
|
||||
func (s fakeStorage) StoreUser(email string, data *UserData) error {
|
||||
panic("no impl")
|
||||
}
|
||||
|
||||
func (s fakeStorage) MostRecentUserEmail() string {
|
||||
panic("no impl")
|
||||
}
|
||||
|
||||
+2
-240
@@ -15,249 +15,20 @@
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ocsp"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
// loadPrivateKey loads a PEM-encoded ECC/RSA private key from an array of bytes.
|
||||
func loadPrivateKey(keyBytes []byte) (crypto.PrivateKey, error) {
|
||||
keyBlock, _ := pem.Decode(keyBytes)
|
||||
|
||||
switch keyBlock.Type {
|
||||
case "RSA PRIVATE KEY":
|
||||
return x509.ParsePKCS1PrivateKey(keyBlock.Bytes)
|
||||
case "EC PRIVATE KEY":
|
||||
return x509.ParseECPrivateKey(keyBlock.Bytes)
|
||||
}
|
||||
|
||||
return nil, errors.New("unknown private key type")
|
||||
}
|
||||
|
||||
// savePrivateKey saves a PEM-encoded ECC/RSA private key to an array of bytes.
|
||||
func savePrivateKey(key crypto.PrivateKey) ([]byte, error) {
|
||||
var pemType string
|
||||
var keyBytes []byte
|
||||
switch key := key.(type) {
|
||||
case *ecdsa.PrivateKey:
|
||||
var err error
|
||||
pemType = "EC"
|
||||
keyBytes, err = x509.MarshalECPrivateKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case *rsa.PrivateKey:
|
||||
pemType = "RSA"
|
||||
keyBytes = x509.MarshalPKCS1PrivateKey(key)
|
||||
}
|
||||
|
||||
pemKey := pem.Block{Type: pemType + " PRIVATE KEY", Bytes: keyBytes}
|
||||
return pem.EncodeToMemory(&pemKey), nil
|
||||
}
|
||||
|
||||
// stapleOCSP staples OCSP information to cert for hostname name.
|
||||
// If you have it handy, you should pass in the PEM-encoded certificate
|
||||
// bundle; otherwise the DER-encoded cert will have to be PEM-encoded.
|
||||
// If you don't have the PEM blocks already, just pass in nil.
|
||||
//
|
||||
// Errors here are not necessarily fatal, it could just be that the
|
||||
// certificate doesn't have an issuer URL.
|
||||
func stapleOCSP(cert *Certificate, pemBundle []byte) error {
|
||||
if pemBundle == nil {
|
||||
// The function in the acme package that gets OCSP requires a PEM-encoded cert
|
||||
bundle := new(bytes.Buffer)
|
||||
for _, derBytes := range cert.Certificate.Certificate {
|
||||
pem.Encode(bundle, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
}
|
||||
pemBundle = bundle.Bytes()
|
||||
}
|
||||
|
||||
var ocspBytes []byte
|
||||
var ocspResp *ocsp.Response
|
||||
var ocspErr error
|
||||
var gotNewOCSP bool
|
||||
|
||||
// First try to load OCSP staple from storage and see if
|
||||
// we can still use it.
|
||||
// TODO: Use Storage interface instead of disk directly
|
||||
var ocspFileNamePrefix string
|
||||
if len(cert.Names) > 0 {
|
||||
ocspFileNamePrefix = cert.Names[0] + "-"
|
||||
}
|
||||
ocspFileName := ocspFileNamePrefix + fastHash(pemBundle)
|
||||
ocspCachePath := filepath.Join(ocspFolder, ocspFileName)
|
||||
cachedOCSP, err := ioutil.ReadFile(ocspCachePath)
|
||||
if err == nil {
|
||||
resp, err := ocsp.ParseResponse(cachedOCSP, nil)
|
||||
if err == nil {
|
||||
if freshOCSP(resp) {
|
||||
// staple is still fresh; use it
|
||||
ocspBytes = cachedOCSP
|
||||
ocspResp = resp
|
||||
}
|
||||
} else {
|
||||
// invalid contents; delete the file
|
||||
// (we do this independently of the maintenance routine because
|
||||
// in this case we know for sure this should be a staple file
|
||||
// because we loaded it by name, whereas the maintenance routine
|
||||
// just iterates the list of files, even if somehow a non-staple
|
||||
// file gets in the folder. in this case we are sure it is corrupt.)
|
||||
err := os.Remove(ocspCachePath)
|
||||
if err != nil {
|
||||
log.Printf("[WARNING] Unable to delete invalid OCSP staple file: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we couldn't get a fresh staple by reading the cache,
|
||||
// then we need to request it from the OCSP responder
|
||||
if ocspResp == nil || len(ocspBytes) == 0 {
|
||||
ocspBytes, ocspResp, ocspErr = acme.GetOCSPForCert(pemBundle)
|
||||
if ocspErr != nil {
|
||||
// An error here is not a problem because a certificate may simply
|
||||
// not contain a link to an OCSP server. But we should log it anyway.
|
||||
// There's nothing else we can do to get OCSP for this certificate,
|
||||
// so we can return here with the error.
|
||||
return fmt.Errorf("no OCSP stapling for %v: %v", cert.Names, ocspErr)
|
||||
}
|
||||
gotNewOCSP = true
|
||||
}
|
||||
|
||||
// By now, we should have a response. If good, staple it to
|
||||
// the certificate. If the OCSP response was not loaded from
|
||||
// storage, we persist it for next time.
|
||||
if ocspResp.Status == ocsp.Good {
|
||||
if ocspResp.NextUpdate.After(cert.NotAfter) {
|
||||
// uh oh, this OCSP response expires AFTER the certificate does, that's kinda bogus.
|
||||
// it was the reason a lot of Symantec-validated sites (not Caddy) went down
|
||||
// in October 2017. https://twitter.com/mattiasgeniar/status/919432824708648961
|
||||
return fmt.Errorf("invalid: OCSP response for %v valid after certificate expiration (%s)",
|
||||
cert.Names, cert.NotAfter.Sub(ocspResp.NextUpdate))
|
||||
}
|
||||
cert.Certificate.OCSPStaple = ocspBytes
|
||||
cert.OCSP = ocspResp
|
||||
if gotNewOCSP {
|
||||
err := os.MkdirAll(filepath.Join(caddy.AssetsPath(), "ocsp"), 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to make OCSP staple path for %v: %v", cert.Names, err)
|
||||
}
|
||||
err = ioutil.WriteFile(ocspCachePath, ocspBytes, 0644)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to write OCSP staple file for %v: %v", cert.Names, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// makeSelfSignedCert makes a self-signed certificate according
|
||||
// to the parameters in config. It then caches the certificate
|
||||
// in our cache.
|
||||
func makeSelfSignedCert(config *Config) error {
|
||||
// start by generating private key
|
||||
var privKey interface{}
|
||||
var err error
|
||||
switch config.KeyType {
|
||||
case "", acme.EC256:
|
||||
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
case acme.EC384:
|
||||
privKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
case acme.RSA2048:
|
||||
privKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
case acme.RSA4096:
|
||||
privKey, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||
case acme.RSA8192:
|
||||
privKey, err = rsa.GenerateKey(rand.Reader, 8192)
|
||||
default:
|
||||
return fmt.Errorf("cannot generate private key; unknown key type %v", config.KeyType)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
// create certificate structure with proper values
|
||||
notBefore := time.Now()
|
||||
notAfter := notBefore.Add(24 * time.Hour * 7)
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate serial number: %v", err)
|
||||
}
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{Organization: []string{"Caddy Self-Signed"}},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
if ip := net.ParseIP(config.Hostname); ip != nil {
|
||||
cert.IPAddresses = append(cert.IPAddresses, ip)
|
||||
} else {
|
||||
cert.DNSNames = append(cert.DNSNames, config.Hostname)
|
||||
}
|
||||
|
||||
publicKey := func(privKey interface{}) interface{} {
|
||||
switch k := privKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
case *ecdsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
default:
|
||||
return errors.New("unknown key type")
|
||||
}
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, publicKey(privKey), privKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create certificate: %v", err)
|
||||
}
|
||||
|
||||
chain := [][]byte{derBytes}
|
||||
|
||||
config.cacheCertificate(Certificate{
|
||||
Certificate: tls.Certificate{
|
||||
Certificate: chain,
|
||||
PrivateKey: privKey,
|
||||
Leaf: cert,
|
||||
},
|
||||
Names: cert.DNSNames,
|
||||
NotAfter: cert.NotAfter,
|
||||
Hash: hashCertificateChain(chain),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RotateSessionTicketKeys rotates the TLS session ticket keys
|
||||
// on cfg every TicketRotateInterval. It spawns a new goroutine so
|
||||
// this function does NOT block. It returns a channel you should
|
||||
// close when you are ready to stop the key rotation, like when the
|
||||
// server using cfg is no longer running.
|
||||
//
|
||||
// TODO: See about moving this into CertMagic and using its Storage
|
||||
func RotateSessionTicketKeys(cfg *tls.Config) chan struct{} {
|
||||
ch := make(chan struct{})
|
||||
ticker := time.NewTicker(TicketRotateInterval)
|
||||
@@ -331,15 +102,6 @@ func standaloneTLSTicketKeyRotation(c *tls.Config, ticker *time.Ticker, exitChan
|
||||
}
|
||||
}
|
||||
|
||||
// fastHash hashes input using a hashing algorithm that
|
||||
// is fast, and returns the hash as a hex-encoded string.
|
||||
// Do not use this for cryptographic purposes.
|
||||
func fastHash(input []byte) string {
|
||||
h := fnv.New32a()
|
||||
h.Write([]byte(input))
|
||||
return fmt.Sprintf("%x", h.Sum32())
|
||||
}
|
||||
|
||||
const (
|
||||
// NumTickets is how many tickets to hold and consider
|
||||
// to decrypt TLS sessions.
|
||||
|
||||
@@ -15,83 +15,11 @@
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSaveAndLoadRSAPrivateKey(t *testing.T) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 128) // make tests faster; small key size OK for testing
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test save
|
||||
savedBytes, err := savePrivateKey(privateKey)
|
||||
if err != nil {
|
||||
t.Fatal("error saving private key:", err)
|
||||
}
|
||||
|
||||
// test load
|
||||
loadedKey, err := loadPrivateKey(savedBytes)
|
||||
if err != nil {
|
||||
t.Error("error loading private key:", err)
|
||||
}
|
||||
|
||||
// verify loaded key is correct
|
||||
if !PrivateKeysSame(privateKey, loadedKey) {
|
||||
t.Error("Expected key bytes to be the same, but they weren't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAndLoadECCPrivateKey(t *testing.T) {
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// test save
|
||||
savedBytes, err := savePrivateKey(privateKey)
|
||||
if err != nil {
|
||||
t.Fatal("error saving private key:", err)
|
||||
}
|
||||
|
||||
// test load
|
||||
loadedKey, err := loadPrivateKey(savedBytes)
|
||||
if err != nil {
|
||||
t.Error("error loading private key:", err)
|
||||
}
|
||||
|
||||
// verify loaded key is correct
|
||||
if !PrivateKeysSame(privateKey, loadedKey) {
|
||||
t.Error("Expected key bytes to be the same, but they weren't")
|
||||
}
|
||||
}
|
||||
|
||||
// PrivateKeysSame compares the bytes of a and b and returns true if they are the same.
|
||||
func PrivateKeysSame(a, b crypto.PrivateKey) bool {
|
||||
return bytes.Equal(PrivateKeyBytes(a), PrivateKeyBytes(b))
|
||||
}
|
||||
|
||||
// PrivateKeyBytes returns the bytes of DER-encoded key.
|
||||
func PrivateKeyBytes(key crypto.PrivateKey) []byte {
|
||||
var keyBytes []byte
|
||||
switch key := key.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
keyBytes = x509.MarshalPKCS1PrivateKey(key)
|
||||
case *ecdsa.PrivateKey:
|
||||
keyBytes, _ = x509.MarshalECPrivateKey(key)
|
||||
}
|
||||
return keyBytes
|
||||
}
|
||||
|
||||
func TestStandaloneTLSTicketKeyRotation(t *testing.T) {
|
||||
type syncPkt struct {
|
||||
ticketKey [32]byte
|
||||
|
||||
@@ -1,276 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func init() {
|
||||
RegisterStorageProvider("file", NewFileStorage)
|
||||
}
|
||||
|
||||
// storageBasePath is the root path in which all TLS/ACME assets are
|
||||
// stored. Do not change this value during the lifetime of the program.
|
||||
var storageBasePath = filepath.Join(caddy.AssetsPath(), "acme")
|
||||
|
||||
// NewFileStorage is a StorageConstructor function that creates a new
|
||||
// Storage instance backed by the local disk. The resulting Storage
|
||||
// instance is guaranteed to be non-nil if there is no error.
|
||||
func NewFileStorage(caURL *url.URL) (Storage, error) {
|
||||
storage := &FileStorage{Path: filepath.Join(storageBasePath, caURL.Host)}
|
||||
storage.Locker = &fileStorageLock{caURL: caURL.Host, storage: storage}
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
// FileStorage facilitates forming file paths derived from a root
|
||||
// directory. It is used to get file paths in a consistent,
|
||||
// cross-platform way or persisting ACME assets on the file system.
|
||||
type FileStorage struct {
|
||||
Path string
|
||||
Locker
|
||||
}
|
||||
|
||||
// sites gets the directory that stores site certificate and keys.
|
||||
func (s *FileStorage) sites() string {
|
||||
return filepath.Join(s.Path, "sites")
|
||||
}
|
||||
|
||||
// site returns the path to the folder containing assets for domain.
|
||||
func (s *FileStorage) site(domain string) string {
|
||||
domain = strings.ToLower(domain)
|
||||
return filepath.Join(s.sites(), domain)
|
||||
}
|
||||
|
||||
// siteCertFile returns the path to the certificate file for domain.
|
||||
func (s *FileStorage) siteCertFile(domain string) string {
|
||||
domain = strings.ToLower(domain)
|
||||
return filepath.Join(s.site(domain), domain+".crt")
|
||||
}
|
||||
|
||||
// siteKeyFile returns the path to domain's private key file.
|
||||
func (s *FileStorage) siteKeyFile(domain string) string {
|
||||
domain = strings.ToLower(domain)
|
||||
return filepath.Join(s.site(domain), domain+".key")
|
||||
}
|
||||
|
||||
// siteMetaFile returns the path to the domain's asset metadata file.
|
||||
func (s *FileStorage) siteMetaFile(domain string) string {
|
||||
domain = strings.ToLower(domain)
|
||||
return filepath.Join(s.site(domain), domain+".json")
|
||||
}
|
||||
|
||||
// users gets the directory that stores account folders.
|
||||
func (s *FileStorage) users() string {
|
||||
return filepath.Join(s.Path, "users")
|
||||
}
|
||||
|
||||
// user gets the account folder for the user with email
|
||||
func (s *FileStorage) user(email string) string {
|
||||
if email == "" {
|
||||
email = emptyEmail
|
||||
}
|
||||
email = strings.ToLower(email)
|
||||
return filepath.Join(s.users(), email)
|
||||
}
|
||||
|
||||
// emailUsername returns the username portion of an email address (part before
|
||||
// '@') or the original input if it can't find the "@" symbol.
|
||||
func emailUsername(email string) string {
|
||||
at := strings.Index(email, "@")
|
||||
if at == -1 {
|
||||
return email
|
||||
} else if at == 0 {
|
||||
return email[1:]
|
||||
}
|
||||
return email[:at]
|
||||
}
|
||||
|
||||
// userRegFile gets the path to the registration file for the user with the
|
||||
// given email address.
|
||||
func (s *FileStorage) userRegFile(email string) string {
|
||||
if email == "" {
|
||||
email = emptyEmail
|
||||
}
|
||||
email = strings.ToLower(email)
|
||||
fileName := emailUsername(email)
|
||||
if fileName == "" {
|
||||
fileName = "registration"
|
||||
}
|
||||
return filepath.Join(s.user(email), fileName+".json")
|
||||
}
|
||||
|
||||
// userKeyFile gets the path to the private key file for the user with the
|
||||
// given email address.
|
||||
func (s *FileStorage) userKeyFile(email string) string {
|
||||
if email == "" {
|
||||
email = emptyEmail
|
||||
}
|
||||
email = strings.ToLower(email)
|
||||
fileName := emailUsername(email)
|
||||
if fileName == "" {
|
||||
fileName = "private"
|
||||
}
|
||||
return filepath.Join(s.user(email), fileName+".key")
|
||||
}
|
||||
|
||||
// readFile abstracts a simple ioutil.ReadFile, making sure to return an
|
||||
// ErrNotExist instance when the file is not found.
|
||||
func (s *FileStorage) readFile(file string) ([]byte, error) {
|
||||
b, err := ioutil.ReadFile(file)
|
||||
if os.IsNotExist(err) {
|
||||
return nil, ErrNotExist(err)
|
||||
}
|
||||
return b, err
|
||||
}
|
||||
|
||||
// SiteExists implements Storage.SiteExists by checking for the presence of
|
||||
// cert and key files.
|
||||
func (s *FileStorage) SiteExists(domain string) (bool, error) {
|
||||
_, err := os.Stat(s.siteCertFile(domain))
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
} else if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_, err = os.Stat(s.siteKeyFile(domain))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// LoadSite implements Storage.LoadSite by loading it from disk. If it is not
|
||||
// present, an instance of ErrNotExist is returned.
|
||||
func (s *FileStorage) LoadSite(domain string) (*SiteData, error) {
|
||||
var err error
|
||||
siteData := new(SiteData)
|
||||
siteData.Cert, err = s.readFile(s.siteCertFile(domain))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
siteData.Key, err = s.readFile(s.siteKeyFile(domain))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
siteData.Meta, err = s.readFile(s.siteMetaFile(domain))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return siteData, nil
|
||||
}
|
||||
|
||||
// StoreSite implements Storage.StoreSite by writing it to disk. The base
|
||||
// directories needed for the file are automatically created as needed.
|
||||
func (s *FileStorage) StoreSite(domain string, data *SiteData) error {
|
||||
err := os.MkdirAll(s.site(domain), 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("making site directory: %v", err)
|
||||
}
|
||||
err = ioutil.WriteFile(s.siteCertFile(domain), data.Cert, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing certificate file: %v", err)
|
||||
}
|
||||
err = ioutil.WriteFile(s.siteKeyFile(domain), data.Key, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing key file: %v", err)
|
||||
}
|
||||
err = ioutil.WriteFile(s.siteMetaFile(domain), data.Meta, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing cert meta file: %v", err)
|
||||
}
|
||||
log.Printf("[INFO][%v] Certificate written to disk: %v", domain, s.siteCertFile(domain))
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSite implements Storage.DeleteSite by deleting just the cert from
|
||||
// disk. If it is not present, an instance of ErrNotExist is returned.
|
||||
func (s *FileStorage) DeleteSite(domain string) error {
|
||||
err := os.Remove(s.siteCertFile(domain))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return ErrNotExist(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadUser implements Storage.LoadUser by loading it from disk. If it is not
|
||||
// present, an instance of ErrNotExist is returned.
|
||||
func (s *FileStorage) LoadUser(email string) (*UserData, error) {
|
||||
var err error
|
||||
userData := new(UserData)
|
||||
userData.Reg, err = s.readFile(s.userRegFile(email))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userData.Key, err = s.readFile(s.userKeyFile(email))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// StoreUser implements Storage.StoreUser by writing it to disk. The base
|
||||
// directories needed for the file are automatically created as needed.
|
||||
func (s *FileStorage) StoreUser(email string, data *UserData) error {
|
||||
err := os.MkdirAll(s.user(email), 0700)
|
||||
if err != nil {
|
||||
return fmt.Errorf("making user directory: %v", err)
|
||||
}
|
||||
err = ioutil.WriteFile(s.userRegFile(email), data.Reg, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing user registration file: %v", err)
|
||||
}
|
||||
err = ioutil.WriteFile(s.userKeyFile(email), data.Key, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing user key file: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MostRecentUserEmail implements Storage.MostRecentUserEmail by finding the
|
||||
// most recently written sub directory in the users' directory. It is named
|
||||
// after the email address. This corresponds to the most recent call to
|
||||
// StoreUser.
|
||||
func (s *FileStorage) MostRecentUserEmail() string {
|
||||
userDirs, err := ioutil.ReadDir(s.users())
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var mostRecent os.FileInfo
|
||||
for _, dir := range userDirs {
|
||||
if !dir.IsDir() {
|
||||
continue
|
||||
}
|
||||
if mostRecent == nil || dir.ModTime().After(mostRecent.ModTime()) {
|
||||
mostRecent = dir
|
||||
}
|
||||
}
|
||||
if mostRecent != nil {
|
||||
return mostRecent.Name()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
// *********************************** NOTE ********************************
|
||||
// Due to circular package dependencies with the storagetest sub package and
|
||||
// the fact that we want to use that harness to test file storage, the tests
|
||||
// for file storage are done in the storagetest package.
|
||||
@@ -1,127 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// be sure to remove lock files when exiting the process!
|
||||
caddy.OnProcessExit = append(caddy.OnProcessExit, func() {
|
||||
fileStorageNameLocksMu.Lock()
|
||||
defer fileStorageNameLocksMu.Unlock()
|
||||
for key, fw := range fileStorageNameLocks {
|
||||
os.Remove(fw.filename)
|
||||
delete(fileStorageNameLocks, key)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// fileStorageLock facilitates ACME-related locking by using
|
||||
// the associated FileStorage, so multiple processes can coordinate
|
||||
// renewals on the certificates on a shared file system.
|
||||
type fileStorageLock struct {
|
||||
caURL string
|
||||
storage *FileStorage
|
||||
}
|
||||
|
||||
// TryLock attempts to get a lock for name, otherwise it returns
|
||||
// a Waiter value to wait until the other process is finished.
|
||||
func (s *fileStorageLock) TryLock(name string) (Waiter, error) {
|
||||
fileStorageNameLocksMu.Lock()
|
||||
defer fileStorageNameLocksMu.Unlock()
|
||||
|
||||
// see if lock already exists within this process
|
||||
fw, ok := fileStorageNameLocks[s.caURL+name]
|
||||
if ok {
|
||||
// lock already created within process, let caller wait on it
|
||||
return fw, nil
|
||||
}
|
||||
|
||||
// attempt to persist lock to disk by creating lock file
|
||||
fw = &fileWaiter{
|
||||
filename: s.storage.siteCertFile(name) + ".lock",
|
||||
wg: new(sync.WaitGroup),
|
||||
}
|
||||
// parent dir must exist
|
||||
if err := os.MkdirAll(s.storage.site(name), 0700); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lf, err := os.OpenFile(fw.filename, os.O_CREATE|os.O_EXCL, 0644)
|
||||
if err != nil {
|
||||
if os.IsExist(err) {
|
||||
// another process has the lock; use it to wait
|
||||
return fw, nil
|
||||
}
|
||||
// otherwise, this was some unexpected error
|
||||
return nil, err
|
||||
}
|
||||
lf.Close()
|
||||
|
||||
// looks like we get the lock
|
||||
fw.wg.Add(1)
|
||||
fileStorageNameLocks[s.caURL+name] = fw
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Unlock unlocks name.
|
||||
func (s *fileStorageLock) Unlock(name string) error {
|
||||
fileStorageNameLocksMu.Lock()
|
||||
defer fileStorageNameLocksMu.Unlock()
|
||||
fw, ok := fileStorageNameLocks[s.caURL+name]
|
||||
if !ok {
|
||||
return fmt.Errorf("FileStorage: no lock to release for %s", name)
|
||||
}
|
||||
os.Remove(fw.filename)
|
||||
fw.wg.Done()
|
||||
delete(fileStorageNameLocks, s.caURL+name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// fileWaiter waits for a file to disappear; it polls
|
||||
// the file system to check for the existence of a file.
|
||||
// It also has a WaitGroup which will be faster than
|
||||
// polling, for when locking need only happen within this
|
||||
// process.
|
||||
type fileWaiter struct {
|
||||
filename string
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
// Wait waits until the lock is released.
|
||||
func (fw *fileWaiter) Wait() {
|
||||
start := time.Now()
|
||||
fw.wg.Wait()
|
||||
for time.Since(start) < 1*time.Hour {
|
||||
_, err := os.Stat(fw.filename)
|
||||
if os.IsNotExist(err) {
|
||||
return
|
||||
}
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
var fileStorageNameLocks = make(map[string]*fileWaiter) // keyed by CA + name
|
||||
var fileStorageNameLocksMu sync.Mutex
|
||||
|
||||
var _ Locker = &fileStorageLock{}
|
||||
var _ Waiter = &fileWaiter{}
|
||||
+40
-397
@@ -16,15 +16,10 @@ package caddytls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
)
|
||||
|
||||
// configGroup is a type that keys configs by their hostname
|
||||
@@ -59,19 +54,14 @@ func (cg configGroup) getConfig(name string) *Config {
|
||||
}
|
||||
}
|
||||
|
||||
// try a config that serves all names (this
|
||||
// is basically the same as a config defined
|
||||
// for "*" -- I think -- but the above loop
|
||||
// doesn't try an empty string)
|
||||
// try a config that serves all names (the above
|
||||
// loop doesn't try empty string; for hosts defined
|
||||
// with only a port, for instance, like ":443") -
|
||||
// also known as the default config
|
||||
if config, ok := cg[""]; ok {
|
||||
return config
|
||||
}
|
||||
|
||||
// no matches, so just serve up a random config
|
||||
for _, config := range cg {
|
||||
return config
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -89,390 +79,43 @@ func (cg configGroup) GetConfigForClient(clientHello *tls.ClientHelloInfo) (*tls
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// GetCertificate gets a certificate to satisfy clientHello. In getting
|
||||
// the certificate, it abides the rules and settings defined in the
|
||||
// Config that matches clientHello.ServerName. It first checks the in-
|
||||
// memory cache, then, if the config enables "OnDemand", it accesses
|
||||
// disk, then accesses the network if it must obtain a new certificate
|
||||
// via ACME.
|
||||
//
|
||||
// This method is safe for use as a tls.Config.GetCertificate callback.
|
||||
func (cfg *Config) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
cert, err := cfg.getCertDuringHandshake(strings.ToLower(clientHello.ServerName), true, true)
|
||||
return &cert.Certificate, err
|
||||
// ClientHelloInfo is our own version of the standard lib's
|
||||
// tls.ClientHelloInfo. As of May 2018, any fields populated
|
||||
// by the Go standard library are not guaranteed to have their
|
||||
// values in the original order as on the wire.
|
||||
type ClientHelloInfo struct {
|
||||
Version uint16 `json:"version,omitempty"`
|
||||
CipherSuites []uint16 `json:"cipher_suites,omitempty"`
|
||||
Extensions []uint16 `json:"extensions,omitempty"`
|
||||
CompressionMethods []byte `json:"compression,omitempty"`
|
||||
Curves []tls.CurveID `json:"curves,omitempty"`
|
||||
Points []uint8 `json:"points,omitempty"`
|
||||
|
||||
// Whether a couple of fields are unknown; if not, the key will encode
|
||||
// differently to reflect that, as opposed to being known empty values.
|
||||
// (some fields may be unknown depending on what package is being used;
|
||||
// i.e. the Go standard lib doesn't expose some things)
|
||||
// (very important to NOT encode these to JSON)
|
||||
ExtensionsUnknown bool `json:"-"`
|
||||
CompressionMethodsUnknown bool `json:"-"`
|
||||
}
|
||||
|
||||
// getCertificate gets a certificate that matches name (a server name)
|
||||
// from the in-memory cache, according to the lookup table associated with
|
||||
// cfg. The lookup then points to a certificate in the Instance certificate
|
||||
// cache.
|
||||
//
|
||||
// If there is no exact match for name, it will be checked against names of
|
||||
// the form '*.example.com' (wildcard certificates) according to RFC 6125.
|
||||
// If a match is found, matched will be true. If no matches are found, matched
|
||||
// will be false and a "default" certificate will be returned with defaulted
|
||||
// set to true. If defaulted is false, then no certificates were available.
|
||||
//
|
||||
// The logic in this function is adapted from the Go standard library,
|
||||
// which is by the Go Authors.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func (cfg *Config) getCertificate(name string) (cert Certificate, matched, defaulted bool) {
|
||||
var certKey string
|
||||
var ok bool
|
||||
|
||||
// Not going to trim trailing dots here since RFC 3546 says,
|
||||
// "The hostname is represented ... without a trailing dot."
|
||||
// Just normalize to lowercase.
|
||||
name = strings.ToLower(name)
|
||||
|
||||
cfg.certCache.RLock()
|
||||
defer cfg.certCache.RUnlock()
|
||||
|
||||
// exact match? great, let's use it
|
||||
if certKey, ok = cfg.Certificates[name]; ok {
|
||||
cert = cfg.certCache.cache[certKey]
|
||||
matched = true
|
||||
return
|
||||
// Key returns a standardized string form of the data in info,
|
||||
// useful for identifying duplicates.
|
||||
func (info ClientHelloInfo) Key() string {
|
||||
extensions, compressionMethods := "?", "?"
|
||||
if !info.ExtensionsUnknown {
|
||||
extensions = fmt.Sprintf("%x", info.Extensions)
|
||||
}
|
||||
|
||||
// try replacing labels in the name with wildcards until we get a match
|
||||
labels := strings.Split(name, ".")
|
||||
for i := range labels {
|
||||
labels[i] = "*"
|
||||
candidate := strings.Join(labels, ".")
|
||||
if certKey, ok = cfg.Certificates[candidate]; ok {
|
||||
cert = cfg.certCache.cache[certKey]
|
||||
matched = true
|
||||
return
|
||||
}
|
||||
if !info.CompressionMethodsUnknown {
|
||||
compressionMethods = fmt.Sprintf("%x", info.CompressionMethods)
|
||||
}
|
||||
|
||||
// check the certCache directly to see if the SNI name is
|
||||
// already the key of the certificate it wants! this is vital
|
||||
// for supporting the TLS-SNI challenge, since the tlsSNISolver
|
||||
// just puts the temporary certificate in the instance cache,
|
||||
// with no regard for configs; this also means that the SNI
|
||||
// can contain the hash of a specific cert (chain) it wants
|
||||
// and we will still be able to serve it up
|
||||
// (this behavior, by the way, could be controversial as to
|
||||
// whether it complies with RFC 6066 about SNI, but I think
|
||||
// it does soooo...)
|
||||
// NOTE/TODO: TLS-SNI challenge is changing, as of Jan. 2018
|
||||
// but what will be different, if it ever returns, is unclear
|
||||
if directCert, ok := cfg.certCache.cache[name]; ok {
|
||||
cert = directCert
|
||||
matched = true
|
||||
return
|
||||
}
|
||||
|
||||
// if nothing matches and SNI was not provided, use a random
|
||||
// certificate; at least there's a chance this older client
|
||||
// can connect, and in the future we won't need this provision
|
||||
// (if SNI is present, it's probably best to just raise a TLS
|
||||
// alert by not serving a certificate)
|
||||
if name == "" {
|
||||
for _, certKey := range cfg.Certificates {
|
||||
defaulted = true
|
||||
cert = cfg.certCache.cache[certKey]
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
return telemetry.FastHash([]byte(fmt.Sprintf("%x-%x-%s-%s-%x-%x",
|
||||
info.Version, info.CipherSuites, extensions,
|
||||
compressionMethods, info.Curves, info.Points)))
|
||||
}
|
||||
|
||||
// getCertDuringHandshake will get a certificate for name. It first tries
|
||||
// the in-memory cache. If no certificate for name is in the cache, the
|
||||
// config most closely corresponding to name will be loaded. If that config
|
||||
// allows it (OnDemand==true) and if loadIfNecessary == true, it goes to disk
|
||||
// to load it into the cache and serve it. If it's not on disk and if
|
||||
// obtainIfNecessary == true, the certificate will be obtained from the CA,
|
||||
// cached, and served. If obtainIfNecessary is true, then loadIfNecessary
|
||||
// must also be set to true. An error will be returned if and only if no
|
||||
// certificate is available.
|
||||
//
|
||||
// This function is safe for concurrent use.
|
||||
func (cfg *Config) getCertDuringHandshake(name string, loadIfNecessary, obtainIfNecessary bool) (Certificate, error) {
|
||||
// First check our in-memory cache to see if we've already loaded it
|
||||
cert, matched, defaulted := cfg.getCertificate(name)
|
||||
if matched {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// If OnDemand is enabled, then we might be able to load or
|
||||
// obtain a needed certificate
|
||||
if cfg.OnDemand && loadIfNecessary {
|
||||
// Then check to see if we have one on disk
|
||||
loadedCert, err := cfg.CacheManagedCertificate(name)
|
||||
if err == nil {
|
||||
loadedCert, err = cfg.handshakeMaintenance(name, loadedCert)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Maintaining newly-loaded certificate for %s: %v", name, err)
|
||||
}
|
||||
return loadedCert, nil
|
||||
}
|
||||
if obtainIfNecessary {
|
||||
// By this point, we need to ask the CA for a certificate
|
||||
|
||||
name = strings.ToLower(name)
|
||||
|
||||
// Make sure the certificate should be obtained based on config
|
||||
err := cfg.checkIfCertShouldBeObtained(name)
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
|
||||
// Name has to qualify for a certificate
|
||||
if !HostQualifies(name) {
|
||||
return cert, errors.New("hostname '" + name + "' does not qualify for certificate")
|
||||
}
|
||||
|
||||
// Obtain certificate from the CA
|
||||
return cfg.obtainOnDemandCertificate(name)
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to the default certificate if there is one
|
||||
if defaulted {
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
return Certificate{}, fmt.Errorf("no certificate available for %s", name)
|
||||
}
|
||||
|
||||
// checkIfCertShouldBeObtained checks to see if an on-demand tls certificate
|
||||
// should be obtained for a given domain based upon the config settings. If
|
||||
// a non-nil error is returned, do not issue a new certificate for name.
|
||||
func (cfg *Config) checkIfCertShouldBeObtained(name string) error {
|
||||
// If the "ask" URL is defined in the config, use to determine if a
|
||||
// cert should obtained
|
||||
if cfg.OnDemandState.AskURL != nil {
|
||||
return cfg.checkURLForObtainingNewCerts(name)
|
||||
}
|
||||
|
||||
// Otherwise use the limit defined by the "max_certs" setting
|
||||
return cfg.checkLimitsForObtainingNewCerts(name)
|
||||
}
|
||||
|
||||
func (cfg *Config) checkURLForObtainingNewCerts(name string) error {
|
||||
client := http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
return errors.New("following http redirects is not allowed")
|
||||
},
|
||||
}
|
||||
|
||||
// Copy the URL from the config in order to modify it for this request
|
||||
askURL := new(url.URL)
|
||||
*askURL = *cfg.OnDemandState.AskURL
|
||||
|
||||
query := askURL.Query()
|
||||
query.Set("domain", name)
|
||||
askURL.RawQuery = query.Encode()
|
||||
|
||||
resp, err := client.Get(askURL.String())
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking %v to deterine if certificate for hostname '%s' should be allowed: %v", cfg.OnDemandState.AskURL, name, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode > 299 {
|
||||
return fmt.Errorf("certificate for hostname '%s' not allowed, non-2xx status code %d returned from %v", name, resp.StatusCode, cfg.OnDemandState.AskURL)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkLimitsForObtainingNewCerts checks to see if name can be issued right
|
||||
// now according the maximum count defined in the configuration. If a non-nil
|
||||
// error is returned, do not issue a new certificate for name.
|
||||
func (cfg *Config) checkLimitsForObtainingNewCerts(name string) error {
|
||||
// User can set hard limit for number of certs for the process to issue
|
||||
if cfg.OnDemandState.MaxObtain > 0 &&
|
||||
atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= cfg.OnDemandState.MaxObtain {
|
||||
return fmt.Errorf("%s: maximum certificates issued (%d)", name, cfg.OnDemandState.MaxObtain)
|
||||
}
|
||||
|
||||
// Make sure name hasn't failed a challenge recently
|
||||
failedIssuanceMu.RLock()
|
||||
when, ok := failedIssuance[name]
|
||||
failedIssuanceMu.RUnlock()
|
||||
if ok {
|
||||
return fmt.Errorf("%s: throttled; refusing to issue cert since last attempt on %s failed", name, when.String())
|
||||
}
|
||||
|
||||
// Make sure, if we've issued a few certificates already, that we haven't
|
||||
// issued any recently
|
||||
lastIssueTimeMu.Lock()
|
||||
since := time.Since(lastIssueTime)
|
||||
lastIssueTimeMu.Unlock()
|
||||
if atomic.LoadInt32(&cfg.OnDemandState.ObtainedCount) >= 10 && since < 10*time.Minute {
|
||||
return fmt.Errorf("%s: throttled; last certificate was obtained %v ago", name, since)
|
||||
}
|
||||
|
||||
// Good to go 👍
|
||||
return nil
|
||||
}
|
||||
|
||||
// obtainOnDemandCertificate obtains a certificate for name for the given
|
||||
// name. If another goroutine has already started obtaining a cert for
|
||||
// name, it will wait and use what the other goroutine obtained.
|
||||
//
|
||||
// This function is safe for use by multiple concurrent goroutines.
|
||||
func (cfg *Config) obtainOnDemandCertificate(name string) (Certificate, error) {
|
||||
// We must protect this process from happening concurrently, so synchronize.
|
||||
obtainCertWaitChansMu.Lock()
|
||||
wait, ok := obtainCertWaitChans[name]
|
||||
if ok {
|
||||
// lucky us -- another goroutine is already obtaining the certificate.
|
||||
// wait for it to finish obtaining the cert and then we'll use it.
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
<-wait
|
||||
return cfg.getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// looks like it's up to us to do all the work and obtain the cert.
|
||||
// make a chan others can wait on if needed
|
||||
wait = make(chan struct{})
|
||||
obtainCertWaitChans[name] = wait
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
|
||||
// obtain the certificate
|
||||
log.Printf("[INFO] Obtaining new certificate for %s", name)
|
||||
err := cfg.ObtainCert(name, false)
|
||||
|
||||
// immediately unblock anyone waiting for it; doing this in
|
||||
// a defer would risk deadlock because of the recursive call
|
||||
// to getCertDuringHandshake below when we return!
|
||||
obtainCertWaitChansMu.Lock()
|
||||
close(wait)
|
||||
delete(obtainCertWaitChans, name)
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
// Failed to solve challenge, so don't allow another on-demand
|
||||
// issue for this name to be attempted for a little while.
|
||||
failedIssuanceMu.Lock()
|
||||
failedIssuance[name] = time.Now()
|
||||
go func(name string) {
|
||||
time.Sleep(5 * time.Minute)
|
||||
failedIssuanceMu.Lock()
|
||||
delete(failedIssuance, name)
|
||||
failedIssuanceMu.Unlock()
|
||||
}(name)
|
||||
failedIssuanceMu.Unlock()
|
||||
return Certificate{}, err
|
||||
}
|
||||
|
||||
// Success - update counters and stuff
|
||||
atomic.AddInt32(&cfg.OnDemandState.ObtainedCount, 1)
|
||||
lastIssueTimeMu.Lock()
|
||||
lastIssueTime = time.Now()
|
||||
lastIssueTimeMu.Unlock()
|
||||
|
||||
// certificate is already on disk; now just start over to load it and serve it
|
||||
return cfg.getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// handshakeMaintenance performs a check on cert for expiration and OCSP
|
||||
// validity.
|
||||
//
|
||||
// This function is safe for use by multiple concurrent goroutines.
|
||||
func (cfg *Config) handshakeMaintenance(name string, cert Certificate) (Certificate, error) {
|
||||
// Check cert expiration
|
||||
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
||||
if timeLeft < RenewDurationBefore {
|
||||
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", cert.Names, timeLeft)
|
||||
return cfg.renewDynamicCertificate(name, cert)
|
||||
}
|
||||
|
||||
// Check OCSP staple validity
|
||||
if cert.OCSP != nil {
|
||||
refreshTime := cert.OCSP.ThisUpdate.Add(cert.OCSP.NextUpdate.Sub(cert.OCSP.ThisUpdate) / 2)
|
||||
if time.Now().After(refreshTime) {
|
||||
err := stapleOCSP(&cert, nil)
|
||||
if err != nil {
|
||||
// An error with OCSP stapling is not the end of the world, and in fact, is
|
||||
// quite common considering not all certs have issuer URLs that support it.
|
||||
log.Printf("[ERROR] Getting OCSP for %s: %v", name, err)
|
||||
}
|
||||
cfg.certCache.Lock()
|
||||
cfg.certCache.cache[cert.Hash] = cert
|
||||
cfg.certCache.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// renewDynamicCertificate renews the certificate for name using cfg. It returns the
|
||||
// certificate to use and an error, if any. name should already be lower-cased before
|
||||
// calling this function. name is the name obtained directly from the handshake's
|
||||
// ClientHello.
|
||||
//
|
||||
// This function is safe for use by multiple concurrent goroutines.
|
||||
func (cfg *Config) renewDynamicCertificate(name string, currentCert Certificate) (Certificate, error) {
|
||||
obtainCertWaitChansMu.Lock()
|
||||
wait, ok := obtainCertWaitChans[name]
|
||||
if ok {
|
||||
// lucky us -- another goroutine is already renewing the certificate.
|
||||
// wait for it to finish, then we'll use the new one.
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
<-wait
|
||||
return cfg.getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// looks like it's up to us to do all the work and renew the cert
|
||||
wait = make(chan struct{})
|
||||
obtainCertWaitChans[name] = wait
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
|
||||
// renew and reload the certificate
|
||||
log.Printf("[INFO] Renewing certificate for %s", name)
|
||||
err := cfg.RenewCert(name, false)
|
||||
if err == nil {
|
||||
// even though the recursive nature of the dynamic cert loading
|
||||
// would just call this function anyway, we do it here to
|
||||
// make the replacement as atomic as possible.
|
||||
newCert, err := currentCert.configs[0].CacheManagedCertificate(name)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] loading renewed certificate for %s: %v", name, err)
|
||||
} else {
|
||||
// replace the old certificate with the new one
|
||||
err = cfg.certCache.replaceCertificate(currentCert, newCert)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Replacing certificate for %s: %v", name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// immediately unblock anyone waiting for it; doing this in
|
||||
// a defer would risk deadlock because of the recursive call
|
||||
// to getCertDuringHandshake below when we return!
|
||||
obtainCertWaitChansMu.Lock()
|
||||
close(wait)
|
||||
delete(obtainCertWaitChans, name)
|
||||
obtainCertWaitChansMu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
return Certificate{}, err
|
||||
}
|
||||
|
||||
return cfg.getCertDuringHandshake(name, true, false)
|
||||
}
|
||||
|
||||
// obtainCertWaitChans is used to coordinate obtaining certs for each hostname.
|
||||
var obtainCertWaitChans = make(map[string]chan struct{})
|
||||
var obtainCertWaitChansMu sync.Mutex
|
||||
|
||||
// failedIssuance is a set of names that we recently failed to get a
|
||||
// certificate for from the ACME CA. They are removed after some time.
|
||||
// When a name is in this map, do not issue a certificate for it on-demand.
|
||||
var failedIssuance = make(map[string]time.Time)
|
||||
var failedIssuanceMu sync.RWMutex
|
||||
|
||||
// lastIssueTime records when we last obtained a certificate successfully.
|
||||
// If this value is recent, do not make any on-demand certificate requests.
|
||||
var lastIssueTime time.Time
|
||||
var lastIssueTimeMu sync.Mutex
|
||||
// ClientHelloTelemetry determines whether to report
|
||||
// TLS ClientHellos to telemetry. Disable if doing
|
||||
// it from a different package.
|
||||
var ClientHelloTelemetry = true
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetCertificate(t *testing.T) {
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
|
||||
hello := &tls.ClientHelloInfo{ServerName: "example.com"}
|
||||
helloSub := &tls.ClientHelloInfo{ServerName: "sub.example.com"}
|
||||
helloNoSNI := &tls.ClientHelloInfo{}
|
||||
helloNoMatch := &tls.ClientHelloInfo{ServerName: "nomatch"}
|
||||
|
||||
// When cache is empty
|
||||
if cert, err := cfg.GetCertificate(hello); err == nil {
|
||||
t.Errorf("GetCertificate should return error when cache is empty, got: %v", cert)
|
||||
}
|
||||
if cert, err := cfg.GetCertificate(helloNoSNI); err == nil {
|
||||
t.Errorf("GetCertificate should return error when cache is empty even if server name is blank, got: %v", cert)
|
||||
}
|
||||
|
||||
// When cache has one certificate in it
|
||||
firstCert := Certificate{Names: []string{"example.com"}, Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"example.com"}}}}
|
||||
cfg.cacheCertificate(firstCert)
|
||||
if cert, err := cfg.GetCertificate(hello); err != nil {
|
||||
t.Errorf("Got an error but shouldn't have, when cert exists in cache: %v", err)
|
||||
} else if cert.Leaf.DNSNames[0] != "example.com" {
|
||||
t.Errorf("Got wrong certificate with exact match; expected 'example.com', got: %v", cert)
|
||||
}
|
||||
if _, err := cfg.GetCertificate(helloNoSNI); err != nil {
|
||||
t.Errorf("Got an error with no SNI but shouldn't have, when cert exists in cache: %v", err)
|
||||
}
|
||||
|
||||
// When retrieving wildcard certificate
|
||||
wildcardCert := Certificate{
|
||||
Names: []string{"*.example.com"},
|
||||
Certificate: tls.Certificate{Leaf: &x509.Certificate{DNSNames: []string{"*.example.com"}}},
|
||||
Hash: "(don't overwrite the first one)",
|
||||
}
|
||||
cfg.cacheCertificate(wildcardCert)
|
||||
if cert, err := cfg.GetCertificate(helloSub); err != nil {
|
||||
t.Errorf("Didn't get wildcard cert, got: cert=%v, err=%v ", cert, err)
|
||||
} else if cert.Leaf.DNSNames[0] != "*.example.com" {
|
||||
t.Errorf("Got wrong certificate, expected wildcard: %v", cert)
|
||||
}
|
||||
|
||||
// When cache is NOT empty but there's no SNI
|
||||
if cert, err := cfg.GetCertificate(helloNoSNI); err != nil {
|
||||
t.Errorf("Expected random certificate with no error when no SNI, got err: %v", err)
|
||||
} else if cert == nil || len(cert.Leaf.DNSNames) == 0 {
|
||||
t.Errorf("Expected random cert with no matches, got: %v", cert)
|
||||
}
|
||||
|
||||
// When no certificate matches, raise an alert
|
||||
if _, err := cfg.GetCertificate(helloNoMatch); err == nil {
|
||||
t.Errorf("Expected an error when no certificate matched the SNI, got: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const challengeBasePath = "/.well-known/acme-challenge"
|
||||
|
||||
// HTTPChallengeHandler proxies challenge requests to ACME client if the
|
||||
// request path starts with challengeBasePath, if the HTTP challenge is not
|
||||
// disabled, and if we are known to be obtaining a certificate for the name.
|
||||
// It returns true if it handled the request and no more needs to be done;
|
||||
// it returns false if this call was a no-op and the request still needs handling.
|
||||
func HTTPChallengeHandler(w http.ResponseWriter, r *http.Request, listenHost string) bool {
|
||||
if !strings.HasPrefix(r.URL.Path, challengeBasePath) {
|
||||
return false
|
||||
}
|
||||
if DisableHTTPChallenge {
|
||||
return false
|
||||
}
|
||||
if !namesObtaining.Has(r.Host) {
|
||||
return false
|
||||
}
|
||||
|
||||
scheme := "http"
|
||||
if r.TLS != nil {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
if listenHost == "" {
|
||||
listenHost = "localhost"
|
||||
}
|
||||
|
||||
// always proxy to the DefaultHTTPAlternatePort because obviously the
|
||||
// ACME challenge request already got into one of our HTTP handlers, so
|
||||
// it means we must have started a HTTP listener on the alternate
|
||||
// port instead; which is only accessible via listenHost
|
||||
upstream, err := url.Parse(fmt.Sprintf("%s://%s:%s", scheme, listenHost, DefaultHTTPAlternatePort))
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
log.Printf("[ERROR] ACME proxy handler: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
proxy := httputil.NewSingleHostReverseProxy(upstream)
|
||||
proxy.Transport = &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
proxy.ServeHTTP(w, r)
|
||||
|
||||
return true
|
||||
}
|
||||
@@ -1,84 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHTTPChallengeHandlerNoOp(t *testing.T) {
|
||||
namesObtaining.Add([]string{"localhost"})
|
||||
|
||||
// try base paths and host names that aren't
|
||||
// handled by this handler
|
||||
for _, url := range []string{
|
||||
"http://localhost/",
|
||||
"http://localhost/foo.html",
|
||||
"http://localhost/.git",
|
||||
"http://localhost/.well-known/",
|
||||
"http://localhost/.well-known/acme-challenging",
|
||||
"http://other/.well-known/acme-challenge/foo",
|
||||
} {
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not craft request, got error: %v", err)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
if HTTPChallengeHandler(rw, req, "") {
|
||||
t.Errorf("Got true with this URL, but shouldn't have: %s", url)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTPChallengeHandlerSuccess(t *testing.T) {
|
||||
expectedPath := challengeBasePath + "/asdf"
|
||||
|
||||
// Set up fake acme handler backend to make sure proxying succeeds
|
||||
var proxySuccess bool
|
||||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
proxySuccess = true
|
||||
if r.URL.Path != expectedPath {
|
||||
t.Errorf("Expected path '%s' but got '%s' instead", expectedPath, r.URL.Path)
|
||||
}
|
||||
}))
|
||||
|
||||
// Custom listener that uses the port we expect
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:"+DefaultHTTPAlternatePort)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to start test server listener: %v", err)
|
||||
}
|
||||
ts.Listener = ln
|
||||
|
||||
// Tell this package that we are handling a challenge for 127.0.0.1
|
||||
namesObtaining.Add([]string{"127.0.0.1"})
|
||||
|
||||
// Start our engines and run the test
|
||||
ts.Start()
|
||||
defer ts.Close()
|
||||
req, err := http.NewRequest("GET", "http://127.0.0.1:"+DefaultHTTPAlternatePort+expectedPath, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not craft request, got error: %v", err)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
|
||||
HTTPChallengeHandler(rw, req, "")
|
||||
|
||||
if !proxySuccess {
|
||||
t.Fatal("Expected request to be proxied, but it wasn't")
|
||||
}
|
||||
}
|
||||
@@ -1,364 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
|
||||
"golang.org/x/crypto/ocsp"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// maintain assets while this package is imported, which is
|
||||
// always. we don't ever stop it, since we need it running.
|
||||
go maintainAssets(make(chan struct{}))
|
||||
}
|
||||
|
||||
const (
|
||||
// RenewInterval is how often to check certificates for renewal.
|
||||
RenewInterval = 12 * time.Hour
|
||||
|
||||
// RenewDurationBefore is how long before expiration to renew certificates.
|
||||
RenewDurationBefore = (24 * time.Hour) * 30
|
||||
|
||||
// RenewDurationBeforeAtStartup is how long before expiration to require
|
||||
// a renewed certificate when the process is first starting up (see #1680).
|
||||
// A wider window between RenewDurationBefore and this value will allow
|
||||
// Caddy to start under duress but hopefully this duration will give it
|
||||
// enough time for the blockage to be relieved.
|
||||
RenewDurationBeforeAtStartup = (24 * time.Hour) * 7
|
||||
|
||||
// OCSPInterval is how often to check if OCSP stapling needs updating.
|
||||
OCSPInterval = 1 * time.Hour
|
||||
)
|
||||
|
||||
// maintainAssets is a permanently-blocking function
|
||||
// that loops indefinitely and, on a regular schedule, checks
|
||||
// certificates for expiration and initiates a renewal of certs
|
||||
// that are expiring soon. It also updates OCSP stapling and
|
||||
// performs other maintenance of assets. It should only be
|
||||
// called once per process.
|
||||
//
|
||||
// You must pass in the channel which you'll close when
|
||||
// maintenance should stop, to allow this goroutine to clean up
|
||||
// after itself and unblock. (Not that you HAVE to stop it...)
|
||||
func maintainAssets(stopChan chan struct{}) {
|
||||
renewalTicker := time.NewTicker(RenewInterval)
|
||||
ocspTicker := time.NewTicker(OCSPInterval)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-renewalTicker.C:
|
||||
log.Println("[INFO] Scanning for expiring certificates")
|
||||
RenewManagedCertificates(false)
|
||||
log.Println("[INFO] Done checking certificates")
|
||||
case <-ocspTicker.C:
|
||||
log.Println("[INFO] Scanning for stale OCSP staples")
|
||||
UpdateOCSPStaples()
|
||||
DeleteOldStapleFiles()
|
||||
log.Println("[INFO] Done checking OCSP staples")
|
||||
case <-stopChan:
|
||||
renewalTicker.Stop()
|
||||
ocspTicker.Stop()
|
||||
log.Println("[INFO] Stopped background maintenance routine")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RenewManagedCertificates renews managed certificates,
|
||||
// including ones loaded on-demand.
|
||||
func RenewManagedCertificates(allowPrompts bool) (err error) {
|
||||
for _, inst := range caddy.Instances() {
|
||||
inst.StorageMu.RLock()
|
||||
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
|
||||
inst.StorageMu.RUnlock()
|
||||
if !ok || certCache == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// we use the queues for a very important reason: to do any and all
|
||||
// operations that could require an exclusive write lock outside
|
||||
// of the read lock! otherwise we get a deadlock, yikes. in other
|
||||
// words, our first iteration through the certificate cache does NOT
|
||||
// perform any operations--only queues them--so that more fine-grained
|
||||
// write locks may be obtained during the actual operations.
|
||||
var renewQueue, reloadQueue, deleteQueue []Certificate
|
||||
|
||||
certCache.RLock()
|
||||
for certKey, cert := range certCache.cache {
|
||||
if len(cert.configs) == 0 {
|
||||
// this is bad if this happens, probably a programmer error (oops)
|
||||
log.Printf("[ERROR] No associated TLS config for certificate with names %v; unable to manage", cert.Names)
|
||||
continue
|
||||
}
|
||||
if !cert.configs[0].Managed || cert.configs[0].SelfSigned {
|
||||
continue
|
||||
}
|
||||
|
||||
// the list of names on this cert should never be empty... programmer error?
|
||||
if cert.Names == nil || len(cert.Names) == 0 {
|
||||
log.Printf("[WARNING] Certificate keyed by '%s' has no names: %v - removing from cache", certKey, cert.Names)
|
||||
deleteQueue = append(deleteQueue, cert)
|
||||
continue
|
||||
}
|
||||
|
||||
// if time is up or expires soon, we need to try to renew it
|
||||
timeLeft := cert.NotAfter.Sub(time.Now().UTC())
|
||||
if timeLeft < RenewDurationBefore {
|
||||
// see if the certificate in storage has already been renewed, possibly by another
|
||||
// instance of Caddy that didn't coordinate with this one; if so, just load it (this
|
||||
// might happen if another instance already renewed it - kinda sloppy but checking disk
|
||||
// first is a simple way to possibly drastically reduce rate limit problems)
|
||||
storedCertExpiring, err := managedCertInStorageExpiresSoon(cert)
|
||||
if err != nil {
|
||||
// hmm, weird, but not a big deal, maybe it was deleted or something
|
||||
log.Printf("[NOTICE] Error while checking if certificate for %v in storage is also expiring soon: %v",
|
||||
cert.Names, err)
|
||||
} else if !storedCertExpiring {
|
||||
// if the certificate is NOT expiring soon and there was no error, then we
|
||||
// are good to just reload the certificate from storage instead of repeating
|
||||
// a likely-unnecessary renewal procedure
|
||||
reloadQueue = append(reloadQueue, cert)
|
||||
continue
|
||||
}
|
||||
|
||||
// the certificate in storage has not been renewed yet, so we will do it
|
||||
// NOTE 1: This is not correct 100% of the time, if multiple Caddy instances
|
||||
// happen to run their maintenance checks at approximately the same times;
|
||||
// both might start renewal at about the same time and do two renewals and one
|
||||
// will overwrite the other. Hence TLS storage plugins. This is sort of a TODO.
|
||||
// NOTE 2: It is super-important to note that the TLS-SNI challenge requires
|
||||
// a write lock on the cache in order to complete its challenge, so it is extra
|
||||
// vital that this renew operation does not happen inside our read lock!
|
||||
renewQueue = append(renewQueue, cert)
|
||||
}
|
||||
}
|
||||
certCache.RUnlock()
|
||||
|
||||
// Reload certificates that merely need to be updated in memory
|
||||
for _, oldCert := range reloadQueue {
|
||||
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
|
||||
log.Printf("[INFO] Certificate for %v expires in %v, but is already renewed in storage; reloading stored certificate",
|
||||
oldCert.Names, timeLeft)
|
||||
|
||||
err = certCache.reloadManagedCertificate(oldCert)
|
||||
if err != nil {
|
||||
if allowPrompts {
|
||||
return err // operator is present, so report error immediately
|
||||
}
|
||||
log.Printf("[ERROR] Loading renewed certificate: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Renewal queue
|
||||
for _, oldCert := range renewQueue {
|
||||
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
|
||||
log.Printf("[INFO] Certificate for %v expires in %v; attempting renewal", oldCert.Names, timeLeft)
|
||||
|
||||
// Get the name which we should use to renew this certificate;
|
||||
// we only support managing certificates with one name per cert,
|
||||
// so this should be easy. We can't rely on cert.Config.Hostname
|
||||
// because it may be a wildcard value from the Caddyfile (e.g.
|
||||
// *.something.com) which, as of Jan. 2017, is not supported by ACME.
|
||||
// TODO: ^ ^ ^ (wildcards)
|
||||
renewName := oldCert.Names[0]
|
||||
|
||||
// perform renewal
|
||||
err := oldCert.configs[0].RenewCert(renewName, allowPrompts)
|
||||
if err != nil {
|
||||
if allowPrompts {
|
||||
// Certificate renewal failed and the operator is present. See a discussion
|
||||
// about this in issue 642. For a while, we only stopped if the certificate
|
||||
// was expired, but in reality, there is no difference between reporting
|
||||
// it now versus later, except that there's somebody present to deal with
|
||||
// it right now. Follow-up: See issue 1680. Only fail in this case if the
|
||||
// certificate is dangerously close to expiration.
|
||||
timeLeft := oldCert.NotAfter.Sub(time.Now().UTC())
|
||||
if timeLeft < RenewDurationBeforeAtStartup {
|
||||
return err
|
||||
}
|
||||
}
|
||||
log.Printf("[ERROR] %v", err)
|
||||
if oldCert.configs[0].OnDemand {
|
||||
// loaded dynamically, remove dynamically
|
||||
deleteQueue = append(deleteQueue, oldCert)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// successful renewal, so update in-memory cache by loading
|
||||
// renewed certificate so it will be used with handshakes
|
||||
err = certCache.reloadManagedCertificate(oldCert)
|
||||
if err != nil {
|
||||
if allowPrompts {
|
||||
return err // operator is present, so report error immediately
|
||||
}
|
||||
log.Printf("[ERROR] %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Deletion queue
|
||||
for _, cert := range deleteQueue {
|
||||
certCache.Lock()
|
||||
// remove any pointers to this certificate from Configs
|
||||
for _, cfg := range cert.configs {
|
||||
for name, certKey := range cfg.Certificates {
|
||||
if certKey == cert.Hash {
|
||||
delete(cfg.Certificates, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
// then delete the certificate from the cache
|
||||
delete(certCache.cache, cert.Hash)
|
||||
certCache.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateOCSPStaples updates the OCSP stapling in all
|
||||
// eligible, cached certificates.
|
||||
//
|
||||
// OCSP maintenance strives to abide the relevant points on
|
||||
// Ryan Sleevi's recommendations for good OCSP support:
|
||||
// https://gist.github.com/sleevi/5efe9ef98961ecfb4da8
|
||||
func UpdateOCSPStaples() {
|
||||
for _, inst := range caddy.Instances() {
|
||||
inst.StorageMu.RLock()
|
||||
certCache, ok := inst.Storage[CertCacheInstStorageKey].(*certificateCache)
|
||||
inst.StorageMu.RUnlock()
|
||||
if !ok || certCache == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Create a temporary place to store updates
|
||||
// until we release the potentially long-lived
|
||||
// read lock and use a short-lived write lock
|
||||
// on the certificate cache.
|
||||
type ocspUpdate struct {
|
||||
rawBytes []byte
|
||||
parsed *ocsp.Response
|
||||
}
|
||||
updated := make(map[string]ocspUpdate)
|
||||
|
||||
certCache.RLock()
|
||||
for certHash, cert := range certCache.cache {
|
||||
// no point in updating OCSP for expired certificates
|
||||
if time.Now().After(cert.NotAfter) {
|
||||
continue
|
||||
}
|
||||
|
||||
var lastNextUpdate time.Time
|
||||
if cert.OCSP != nil {
|
||||
lastNextUpdate = cert.OCSP.NextUpdate
|
||||
if freshOCSP(cert.OCSP) {
|
||||
continue // no need to update staple if ours is still fresh
|
||||
}
|
||||
}
|
||||
|
||||
err := stapleOCSP(&cert, nil)
|
||||
if err != nil {
|
||||
if cert.OCSP != nil {
|
||||
// if there was no staple before, that's fine; otherwise we should log the error
|
||||
log.Printf("[ERROR] Checking OCSP: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// By this point, we've obtained the latest OCSP response.
|
||||
// If there was no staple before, or if the response is updated, make
|
||||
// sure we apply the update to all names on the certificate.
|
||||
if cert.OCSP != nil && (lastNextUpdate.IsZero() || lastNextUpdate != cert.OCSP.NextUpdate) {
|
||||
log.Printf("[INFO] Advancing OCSP staple for %v from %s to %s",
|
||||
cert.Names, lastNextUpdate, cert.OCSP.NextUpdate)
|
||||
updated[certHash] = ocspUpdate{rawBytes: cert.Certificate.OCSPStaple, parsed: cert.OCSP}
|
||||
}
|
||||
}
|
||||
certCache.RUnlock()
|
||||
|
||||
// These write locks should be brief since we have all the info we need now.
|
||||
for certKey, update := range updated {
|
||||
certCache.Lock()
|
||||
cert := certCache.cache[certKey]
|
||||
cert.OCSP = update.parsed
|
||||
cert.Certificate.OCSPStaple = update.rawBytes
|
||||
certCache.cache[certKey] = cert
|
||||
certCache.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// DeleteOldStapleFiles deletes cached OCSP staples that have expired.
|
||||
// TODO: Should we do this for certificates too?
|
||||
func DeleteOldStapleFiles() {
|
||||
// TODO: Upgrade caddytls.Storage to support OCSP operations too
|
||||
files, err := ioutil.ReadDir(ocspFolder)
|
||||
if err != nil {
|
||||
// maybe just hasn't been created yet; no big deal
|
||||
return
|
||||
}
|
||||
for _, file := range files {
|
||||
if file.IsDir() {
|
||||
// weird, what's a folder doing inside the OCSP cache?
|
||||
continue
|
||||
}
|
||||
stapleFile := filepath.Join(ocspFolder, file.Name())
|
||||
ocspBytes, err := ioutil.ReadFile(stapleFile)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
resp, err := ocsp.ParseResponse(ocspBytes, nil)
|
||||
if err != nil {
|
||||
// contents are invalid; delete it
|
||||
err = os.Remove(stapleFile)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Purging corrupt staple file %s: %v", stapleFile, err)
|
||||
}
|
||||
}
|
||||
if time.Now().After(resp.NextUpdate) {
|
||||
// response has expired; delete it
|
||||
err = os.Remove(stapleFile)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Purging expired staple file %s: %v", stapleFile, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// freshOCSP returns true if resp is still fresh,
|
||||
// meaning that it is not expedient to get an
|
||||
// updated response from the OCSP server.
|
||||
func freshOCSP(resp *ocsp.Response) bool {
|
||||
nextUpdate := resp.NextUpdate
|
||||
// If there is an OCSP responder certificate, and it expires before the
|
||||
// OCSP response, use its expiration date as the end of the OCSP
|
||||
// response's validity period.
|
||||
if resp.Certificate != nil && resp.Certificate.NotAfter.Before(nextUpdate) {
|
||||
nextUpdate = resp.Certificate.NotAfter
|
||||
}
|
||||
// start checking OCSP staple about halfway through validity period for good measure
|
||||
refreshTime := resp.ThisUpdate.Add(nextUpdate.Sub(resp.ThisUpdate) / 2)
|
||||
return time.Now().Before(refreshTime)
|
||||
}
|
||||
|
||||
var ocspFolder = filepath.Join(caddy.AssetsPath(), "ocsp")
|
||||
@@ -0,0 +1,106 @@
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/xenolf/lego/certcrypto"
|
||||
)
|
||||
|
||||
// newSelfSignedCertificate returns a new self-signed certificate.
|
||||
func newSelfSignedCertificate(ssconfig selfSignedConfig) (tls.Certificate, error) {
|
||||
// start by generating private key
|
||||
var privKey interface{}
|
||||
var err error
|
||||
switch ssconfig.KeyType {
|
||||
case "", certcrypto.EC256:
|
||||
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
case certcrypto.EC384:
|
||||
privKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
case certcrypto.RSA2048:
|
||||
privKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||
case certcrypto.RSA4096:
|
||||
privKey, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||
case certcrypto.RSA8192:
|
||||
privKey, err = rsa.GenerateKey(rand.Reader, 8192)
|
||||
default:
|
||||
return tls.Certificate{}, fmt.Errorf("cannot generate private key; unknown key type %v", ssconfig.KeyType)
|
||||
}
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to generate private key: %v", err)
|
||||
}
|
||||
|
||||
// create certificate structure with proper values
|
||||
notBefore := time.Now()
|
||||
notAfter := ssconfig.Expire
|
||||
if notAfter.IsZero() || notAfter.Before(notBefore) {
|
||||
notAfter = notBefore.Add(24 * time.Hour * 7)
|
||||
}
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("failed to generate serial number: %v", err)
|
||||
}
|
||||
cert := &x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{Organization: []string{"Caddy Self-Signed"}},
|
||||
NotBefore: notBefore,
|
||||
NotAfter: notAfter,
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||
}
|
||||
if len(ssconfig.SAN) == 0 {
|
||||
ssconfig.SAN = []string{""}
|
||||
}
|
||||
var names []string
|
||||
for _, san := range ssconfig.SAN {
|
||||
if ip := net.ParseIP(san); ip != nil {
|
||||
names = append(names, strings.ToLower(ip.String()))
|
||||
cert.IPAddresses = append(cert.IPAddresses, ip)
|
||||
} else {
|
||||
names = append(names, strings.ToLower(san))
|
||||
cert.DNSNames = append(cert.DNSNames, strings.ToLower(san))
|
||||
}
|
||||
}
|
||||
|
||||
// generate the associated public key
|
||||
publicKey := func(privKey interface{}) interface{} {
|
||||
switch k := privKey.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
case *ecdsa.PrivateKey:
|
||||
return &k.PublicKey
|
||||
default:
|
||||
return fmt.Errorf("unknown key type")
|
||||
}
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, publicKey(privKey), privKey)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, fmt.Errorf("could not create certificate: %v", err)
|
||||
}
|
||||
|
||||
chain := [][]byte{derBytes}
|
||||
|
||||
return tls.Certificate{
|
||||
Certificate: chain,
|
||||
PrivateKey: privKey,
|
||||
Leaf: cert,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// selfSignedConfig configures a self-signed certificate.
|
||||
type selfSignedConfig struct {
|
||||
SAN []string
|
||||
KeyType certcrypto.KeyType
|
||||
Expire time.Time
|
||||
}
|
||||
+130
-50
@@ -28,17 +28,21 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
"github.com/mholt/certmagic"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("tls", caddy.Plugin{Action: setupTLS})
|
||||
|
||||
// ensure the default Storage implementation is plugged in
|
||||
caddy.RegisterClusterPlugin("file", constructDefaultClusterPlugin)
|
||||
}
|
||||
|
||||
// setupTLS sets up the TLS configuration and installs certificates that
|
||||
// are specified by the user in the config file. All the automatic HTTPS
|
||||
// stuff comes later outside of this function.
|
||||
func setupTLS(c *caddy.Controller) error {
|
||||
// obtain the configGetter, which loads the config we're, uh, configuring
|
||||
configGetter, ok := configGetters[c.ServerType()]
|
||||
if !ok {
|
||||
return fmt.Errorf("no caddytls.ConfigGetter for %s server type; must call RegisterConfigGetter", c.ServerType())
|
||||
@@ -48,18 +52,68 @@ func setupTLS(c *caddy.Controller) error {
|
||||
return fmt.Errorf("no caddytls.Config to set up for %s", c.Key)
|
||||
}
|
||||
|
||||
// the certificate cache is tied to the current caddy.Instance; get a pointer to it
|
||||
certCache, ok := c.Get(CertCacheInstStorageKey).(*certificateCache)
|
||||
config.Enabled = true
|
||||
|
||||
// a single certificate cache is used by the whole caddy.Instance; get a pointer to it
|
||||
certCache, ok := c.Get(CertCacheInstStorageKey).(*certmagic.Cache)
|
||||
if !ok || certCache == nil {
|
||||
certCache = &certificateCache{cache: make(map[string]Certificate)}
|
||||
certCache = certmagic.NewCache(certmagic.DefaultStorage)
|
||||
c.OnShutdown(func() error {
|
||||
certCache.Stop()
|
||||
return nil
|
||||
})
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
}
|
||||
config.certCache = certCache
|
||||
config.Manager = certmagic.NewWithCache(certCache, certmagic.Config{})
|
||||
|
||||
config.Enabled = true
|
||||
// we use certmagic events to collect metrics for telemetry
|
||||
config.Manager.OnEvent = func(event string, data interface{}) {
|
||||
switch event {
|
||||
case "tls_handshake_started":
|
||||
clientHello := data.(*tls.ClientHelloInfo)
|
||||
if ClientHelloTelemetry && len(clientHello.SupportedVersions) > 0 {
|
||||
// If no other plugin (such as the HTTP server type) is implementing ClientHello telemetry, we do it.
|
||||
// NOTE: The values in the Go standard lib's ClientHelloInfo aren't guaranteed to be in order.
|
||||
info := ClientHelloInfo{
|
||||
Version: clientHello.SupportedVersions[0], // report the highest
|
||||
CipherSuites: clientHello.CipherSuites,
|
||||
ExtensionsUnknown: true, // no extension info... :(
|
||||
CompressionMethodsUnknown: true, // no compression methods... :(
|
||||
Curves: clientHello.SupportedCurves,
|
||||
Points: clientHello.SupportedPoints,
|
||||
// We also have, but do not yet use: SignatureSchemes, ServerName, and SupportedProtos (ALPN)
|
||||
// because the standard lib parses some extensions, but our MITM detector generally doesn't.
|
||||
}
|
||||
go telemetry.SetNested("tls_client_hello", info.Key(), info)
|
||||
}
|
||||
|
||||
case "tls_handshake_completed":
|
||||
// TODO: This is a "best guess" for now - at this point, we only gave a
|
||||
// certificate to the client; we need something listener-level to be sure
|
||||
go telemetry.Increment("tls_handshake_count")
|
||||
|
||||
case "acme_cert_obtained":
|
||||
go telemetry.Increment("tls_acme_certs_obtained")
|
||||
|
||||
case "acme_cert_renewed":
|
||||
name := data.(string)
|
||||
caddy.EmitEvent(caddy.CertRenewEvent, name)
|
||||
go telemetry.Increment("tls_acme_certs_renewed")
|
||||
|
||||
case "acme_cert_revoked":
|
||||
telemetry.Increment("acme_certs_revoked")
|
||||
|
||||
case "cached_managed_cert":
|
||||
telemetry.Increment("tls_managed_cert_count")
|
||||
|
||||
case "cached_unmanaged_cert":
|
||||
telemetry.Increment("tls_unmanaged_cert_count")
|
||||
}
|
||||
}
|
||||
|
||||
for c.Next() {
|
||||
var certificateFile, keyFile, loadDir, maxCerts, askURL string
|
||||
var onDemand bool
|
||||
|
||||
args := c.RemainingArgs()
|
||||
switch len(args) {
|
||||
@@ -95,30 +149,29 @@ func setupTLS(c *caddy.Controller) error {
|
||||
if len(arg) != 1 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
config.CAUrl = arg[0]
|
||||
config.Manager.CA = arg[0]
|
||||
case "key_type":
|
||||
arg := c.RemainingArgs()
|
||||
value, ok := supportedKeyTypes[strings.ToUpper(arg[0])]
|
||||
if !ok {
|
||||
return c.Errf("Wrong key type name or key type not supported: '%s'", c.Val())
|
||||
}
|
||||
config.KeyType = value
|
||||
config.Manager.KeyType = value
|
||||
case "protocols":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) == 1 {
|
||||
value, ok := supportedProtocols[strings.ToLower(args[0])]
|
||||
value, ok := SupportedProtocols[strings.ToLower(args[0])]
|
||||
if !ok {
|
||||
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0])
|
||||
}
|
||||
|
||||
config.ProtocolMinVersion, config.ProtocolMaxVersion = value, value
|
||||
} else {
|
||||
value, ok := supportedProtocols[strings.ToLower(args[0])]
|
||||
value, ok := SupportedProtocols[strings.ToLower(args[0])]
|
||||
if !ok {
|
||||
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[0])
|
||||
}
|
||||
config.ProtocolMinVersion = value
|
||||
value, ok = supportedProtocols[strings.ToLower(args[1])]
|
||||
value, ok = SupportedProtocols[strings.ToLower(args[1])]
|
||||
if !ok {
|
||||
return c.Errf("Wrong protocol name or protocol not supported: '%s'", args[1])
|
||||
}
|
||||
@@ -129,7 +182,7 @@ func setupTLS(c *caddy.Controller) error {
|
||||
}
|
||||
case "ciphers":
|
||||
for c.NextArg() {
|
||||
value, ok := supportedCiphersMap[strings.ToUpper(c.Val())]
|
||||
value, ok := SupportedCiphersMap[strings.ToUpper(c.Val())]
|
||||
if !ok {
|
||||
return c.Errf("Wrong cipher name or cipher not supported: '%s'", c.Val())
|
||||
}
|
||||
@@ -173,30 +226,32 @@ func setupTLS(c *caddy.Controller) error {
|
||||
config.Manual = true
|
||||
case "max_certs":
|
||||
c.Args(&maxCerts)
|
||||
config.OnDemand = true
|
||||
onDemand = true
|
||||
case "ask":
|
||||
c.Args(&askURL)
|
||||
config.OnDemand = true
|
||||
onDemand = true
|
||||
case "dns":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) != 1 {
|
||||
return c.ArgErr()
|
||||
}
|
||||
// TODO: we can get rid of DNS provider plugins with this one line
|
||||
// of code; however, currently (Dec. 2018) this adds about 20 MB
|
||||
// of bloat to the Caddy binary, doubling its size to ~40 MB...!
|
||||
// dnsProv, err := dns.NewDNSChallengeProviderByName(args[0])
|
||||
// if err != nil {
|
||||
// return c.Errf("Configuring DNS provider '%s': %v", args[0], err)
|
||||
// }
|
||||
dnsProvName := args[0]
|
||||
if _, ok := dnsProviders[dnsProvName]; !ok {
|
||||
return c.Errf("Unsupported DNS provider '%s'", args[0])
|
||||
dnsProvConstructor, ok := dnsProviders[dnsProvName]
|
||||
if !ok {
|
||||
return c.Errf("Unknown DNS provider by name '%s'", dnsProvName)
|
||||
}
|
||||
config.DNSProvider = args[0]
|
||||
case "storage":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) != 1 {
|
||||
return c.ArgErr()
|
||||
dnsProv, err := dnsProvConstructor()
|
||||
if err != nil {
|
||||
return c.Errf("Setting up DNS provider '%s': %v", dnsProvName, err)
|
||||
}
|
||||
storageProvName := args[0]
|
||||
if _, ok := storageProviders[storageProvName]; !ok {
|
||||
return c.Errf("Unsupported Storage provider '%s'", args[0])
|
||||
}
|
||||
config.StorageProvider = args[0]
|
||||
config.Manager.DNSProvider = dnsProv
|
||||
case "alpn":
|
||||
args := c.RemainingArgs()
|
||||
if len(args) == 0 {
|
||||
@@ -206,9 +261,22 @@ func setupTLS(c *caddy.Controller) error {
|
||||
config.ALPN = append(config.ALPN, arg)
|
||||
}
|
||||
case "must_staple":
|
||||
config.MustStaple = true
|
||||
config.Manager.MustStaple = true
|
||||
case "wildcard":
|
||||
if !certmagic.HostQualifies(config.Hostname) {
|
||||
return c.Errf("Hostname '%s' does not qualify for managed TLS, so cannot manage wildcard certificate for it", config.Hostname)
|
||||
}
|
||||
if strings.Contains(config.Hostname, "*") {
|
||||
return c.Errf("Cannot convert domain name '%s' to a valid wildcard: already has a wildcard label", config.Hostname)
|
||||
}
|
||||
parts := strings.Split(config.Hostname, ".")
|
||||
if len(parts) < 3 {
|
||||
return c.Errf("Cannot convert domain name '%s' to a valid wildcard: too few labels", config.Hostname)
|
||||
}
|
||||
parts[0] = "*"
|
||||
config.Hostname = strings.Join(parts, ".")
|
||||
default:
|
||||
return c.Errf("Unknown keyword '%s'", c.Val())
|
||||
return c.Errf("Unknown subdirective '%s'", c.Val())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,26 +285,26 @@ func setupTLS(c *caddy.Controller) error {
|
||||
return c.ArgErr()
|
||||
}
|
||||
|
||||
// set certificate limit if on-demand TLS is enabled
|
||||
if maxCerts != "" {
|
||||
maxCertsNum, err := strconv.Atoi(maxCerts)
|
||||
if err != nil || maxCertsNum < 1 {
|
||||
return c.Err("max_certs must be a positive integer")
|
||||
// configure on-demand TLS, if enabled
|
||||
if onDemand {
|
||||
config.Manager.OnDemand = new(certmagic.OnDemandConfig)
|
||||
if maxCerts != "" {
|
||||
maxCertsNum, err := strconv.Atoi(maxCerts)
|
||||
if err != nil || maxCertsNum < 1 {
|
||||
return c.Err("max_certs must be a positive integer")
|
||||
}
|
||||
config.Manager.OnDemand.MaxObtain = int32(maxCertsNum)
|
||||
}
|
||||
config.OnDemandState.MaxObtain = int32(maxCertsNum)
|
||||
}
|
||||
|
||||
if askURL != "" {
|
||||
parsedURL, err := url.Parse(askURL)
|
||||
if err != nil {
|
||||
return c.Err("ask must be a valid url")
|
||||
if askURL != "" {
|
||||
parsedURL, err := url.Parse(askURL)
|
||||
if err != nil {
|
||||
return c.Err("ask must be a valid url")
|
||||
}
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return c.Err("ask URL must use http or https")
|
||||
}
|
||||
config.Manager.OnDemand.AskURL = parsedURL
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return c.Err("ask URL must use http or https")
|
||||
}
|
||||
|
||||
config.OnDemandState.AskURL = parsedURL
|
||||
}
|
||||
|
||||
// don't try to load certificates unless we're supposed to
|
||||
@@ -246,7 +314,7 @@ func setupTLS(c *caddy.Controller) error {
|
||||
|
||||
// load a single certificate and key, if specified
|
||||
if certificateFile != "" && keyFile != "" {
|
||||
err := config.cacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
|
||||
err := config.Manager.CacheUnmanagedCertificatePEMFile(certificateFile, keyFile)
|
||||
if err != nil {
|
||||
return c.Errf("Unable to load certificate and key files for '%s': %v", c.Key, err)
|
||||
}
|
||||
@@ -266,10 +334,18 @@ func setupTLS(c *caddy.Controller) error {
|
||||
|
||||
// generate self-signed cert if needed
|
||||
if config.SelfSigned {
|
||||
err := makeSelfSignedCert(config)
|
||||
ssCert, err := newSelfSignedCertificate(selfSignedConfig{
|
||||
SAN: []string{config.Hostname},
|
||||
KeyType: config.Manager.KeyType,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("self-signed certificate generation: %v", err)
|
||||
}
|
||||
err = config.Manager.CacheUnmanagedTLSCertificate(ssCert)
|
||||
if err != nil {
|
||||
return fmt.Errorf("self-signed: %v", err)
|
||||
}
|
||||
telemetry.Increment("tls_self_signed_count")
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -345,7 +421,7 @@ func loadCertsInDir(cfg *Config, c *caddy.Controller, dir string) error {
|
||||
return c.Errf("%s: no private key block found", path)
|
||||
}
|
||||
|
||||
err = cfg.cacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
|
||||
err = cfg.Manager.CacheUnmanagedCertificatePEMBytes(certPEMBytes, keyPEMBytes)
|
||||
if err != nil {
|
||||
return c.Errf("%s: failed to load cert and key for '%s': %v", path, c.Key, err)
|
||||
}
|
||||
@@ -354,3 +430,7 @@ func loadCertsInDir(cfg *Config, c *caddy.Controller, dir string) error {
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func constructDefaultClusterPlugin() (certmagic.Storage, error) {
|
||||
return &certmagic.FileStorage{Path: caddy.AssetsPath()}, nil
|
||||
}
|
||||
|
||||
+25
-43
@@ -22,7 +22,8 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/mholt/certmagic"
|
||||
"github.com/xenolf/lego/certcrypto"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@@ -46,12 +47,9 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func TestSetupParseBasic(t *testing.T) {
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
|
||||
cfg := &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", `tls `+certFile+` `+keyFile+``)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
@@ -67,8 +65,8 @@ func TestSetupParseBasic(t *testing.T) {
|
||||
}
|
||||
|
||||
// Security defaults
|
||||
if cfg.ProtocolMinVersion != tls.VersionTLS11 {
|
||||
t.Errorf("Expected 'tls1.1 (0x0302)' as ProtocolMinVersion, got %#v", cfg.ProtocolMinVersion)
|
||||
if cfg.ProtocolMinVersion != tls.VersionTLS12 {
|
||||
t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMinVersion, got %#v", cfg.ProtocolMinVersion)
|
||||
}
|
||||
if cfg.ProtocolMaxVersion != tls.VersionTLS12 {
|
||||
t.Errorf("Expected 'tls1.2 (0x0303)' as ProtocolMaxVersion, got %v", cfg.ProtocolMaxVersion)
|
||||
@@ -127,12 +125,10 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
|
||||
must_staple
|
||||
alpn http/1.1
|
||||
}`
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
|
||||
cfg := &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
@@ -151,7 +147,7 @@ func TestSetupParseWithOptionalParams(t *testing.T) {
|
||||
t.Errorf("Expected 3 Ciphers (not including TLS_FALLBACK_SCSV), got %v", len(cfg.Ciphers)-1)
|
||||
}
|
||||
|
||||
if !cfg.MustStaple {
|
||||
if !cfg.Manager.MustStaple {
|
||||
t.Error("Expected must staple to be true")
|
||||
}
|
||||
|
||||
@@ -164,11 +160,9 @@ func TestSetupDefaultWithOptionalParams(t *testing.T) {
|
||||
params := `tls {
|
||||
ciphers RSA-3DES-EDE-CBC-SHA
|
||||
}`
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
cfg := &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
@@ -184,11 +178,9 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
||||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||
protocols ssl tls
|
||||
}`
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
cfg := &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err == nil {
|
||||
@@ -199,10 +191,9 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
||||
params = `tls ` + certFile + ` ` + keyFile + ` {
|
||||
ciphers not-valid-cipher
|
||||
}`
|
||||
cfg = new(Config)
|
||||
cfg = &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c = caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
err = setupTLS(c)
|
||||
if err == nil {
|
||||
t.Error("Expected errors, but no error returned")
|
||||
@@ -212,7 +203,7 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
||||
params = `tls {
|
||||
key_type ab123
|
||||
}`
|
||||
cfg = new(Config)
|
||||
cfg = &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c = caddy.NewTestController("", params)
|
||||
err = setupTLS(c)
|
||||
@@ -224,10 +215,9 @@ func TestSetupParseWithWrongOptionalParams(t *testing.T) {
|
||||
params = `tls {
|
||||
curves ab123, cd456, ef789
|
||||
}`
|
||||
cfg = new(Config)
|
||||
cfg = &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c = caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
err = setupTLS(c)
|
||||
if err == nil {
|
||||
t.Error("Expected errors, but no error returned")
|
||||
@@ -239,8 +229,7 @@ func TestSetupParseWithClientAuth(t *testing.T) {
|
||||
params := `tls ` + certFile + ` ` + keyFile + ` {
|
||||
clients
|
||||
}`
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
cfg := &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
err := setupTLS(c)
|
||||
@@ -273,11 +262,11 @@ func TestSetupParseWithClientAuth(t *testing.T) {
|
||||
clients verify_if_given
|
||||
}`, tls.VerifyClientCertIfGiven, true, noCAs},
|
||||
} {
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
certCache := certmagic.NewCache(certmagic.DefaultStorage)
|
||||
cfg := &Config{Manager: certmagic.NewWithCache(certCache, certmagic.Config{})}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", caseData.params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if caseData.expectedErr {
|
||||
if err == nil {
|
||||
@@ -327,11 +316,10 @@ func TestSetupParseWithCAUrl(t *testing.T) {
|
||||
ca 1 2
|
||||
}`, true, ""},
|
||||
} {
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
cfg := &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", caseData.params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if caseData.expectedErr {
|
||||
if err == nil {
|
||||
@@ -343,8 +331,8 @@ func TestSetupParseWithCAUrl(t *testing.T) {
|
||||
t.Errorf("In case %d: Expected no errors, got: %v", caseNumber, err)
|
||||
}
|
||||
|
||||
if cfg.CAUrl != caseData.expectedCAUrl {
|
||||
t.Errorf("Expected '%v' as CAUrl, got %#v", caseData.expectedCAUrl, cfg.CAUrl)
|
||||
if cfg.Manager.CA != caseData.expectedCAUrl {
|
||||
t.Errorf("Expected '%v' as CAUrl, got %#v", caseData.expectedCAUrl, cfg.Manager.CA)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -353,19 +341,17 @@ func TestSetupParseWithKeyType(t *testing.T) {
|
||||
params := `tls {
|
||||
key_type p384
|
||||
}`
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
cfg := &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no errors, got: %v", err)
|
||||
}
|
||||
|
||||
if cfg.KeyType != acme.EC384 {
|
||||
t.Errorf("Expected 'P384' as KeyType, got %#v", cfg.KeyType)
|
||||
if cfg.Manager.KeyType != certcrypto.EC384 {
|
||||
t.Errorf("Expected 'P384' as KeyType, got %#v", cfg.Manager.KeyType)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,11 +359,9 @@ func TestSetupParseWithCurves(t *testing.T) {
|
||||
params := `tls {
|
||||
curves x25519 p256 p384 p521
|
||||
}`
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
cfg := &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
@@ -402,11 +386,9 @@ func TestSetupParseWithOneTLSProtocol(t *testing.T) {
|
||||
params := `tls {
|
||||
protocols tls1.2
|
||||
}`
|
||||
certCache := &certificateCache{cache: make(map[string]Certificate)}
|
||||
cfg := &Config{Certificates: make(map[string]string), certCache: certCache}
|
||||
cfg := &Config{Manager: &certmagic.Config{}}
|
||||
RegisterConfigGetter("", func(c *caddy.Controller) *Config { return cfg })
|
||||
c := caddy.NewTestController("", params)
|
||||
c.Set(CertCacheInstStorageKey, certCache)
|
||||
|
||||
err := setupTLS(c)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import "net/url"
|
||||
|
||||
// StorageConstructor is a function type that is used in the Config to
|
||||
// instantiate a new Storage instance. This function can return a nil
|
||||
// Storage even without an error.
|
||||
type StorageConstructor func(caURL *url.URL) (Storage, error)
|
||||
|
||||
// SiteData contains persisted items pertaining to an individual site.
|
||||
type SiteData struct {
|
||||
// Cert is the public cert byte array.
|
||||
Cert []byte
|
||||
// Key is the private key byte array.
|
||||
Key []byte
|
||||
// Meta is metadata about the site used by Caddy.
|
||||
Meta []byte
|
||||
}
|
||||
|
||||
// UserData contains persisted items pertaining to a user.
|
||||
type UserData struct {
|
||||
// Reg is the user registration byte array.
|
||||
Reg []byte
|
||||
// Key is the user key byte array.
|
||||
Key []byte
|
||||
}
|
||||
|
||||
// Locker provides support for mutual exclusion
|
||||
type Locker interface {
|
||||
// TryLock will return immediatedly with or without acquiring the lock.
|
||||
// If a lock could be obtained, (nil, nil) is returned and you may
|
||||
// continue normally. If not (meaning another process is already
|
||||
// working on that name), a Waiter value will be returned upon
|
||||
// which you can Wait() until it is finished, and then return
|
||||
// when it unblocks. If waiting, do not unlock!
|
||||
//
|
||||
// To prevent deadlocks, all implementations (where this concern
|
||||
// is relevant) should put a reasonable expiration on the lock in
|
||||
// case Unlock is unable to be called due to some sort of storage
|
||||
// system failure or crash.
|
||||
TryLock(name string) (Waiter, error)
|
||||
|
||||
// Unlock unlocks the mutex for name. Only callers of TryLock who
|
||||
// successfully obtained the lock (no Waiter value was returned)
|
||||
// should call this method, and it should be called only after
|
||||
// the obtain/renew and store are finished, even if there was
|
||||
// an error (or a timeout).
|
||||
Unlock(name string) error
|
||||
}
|
||||
|
||||
// Storage is an interface abstracting all storage used by Caddy's TLS
|
||||
// subsystem. Implementations of this interface store both site and
|
||||
// user data.
|
||||
type Storage interface {
|
||||
// SiteExists returns true if this site exists in storage.
|
||||
// Site data is considered present when StoreSite has been called
|
||||
// successfully (without DeleteSite having been called, of course).
|
||||
SiteExists(domain string) (bool, error)
|
||||
|
||||
// LoadSite obtains the site data from storage for the given domain and
|
||||
// returns it. If data for the domain does not exist, an error value
|
||||
// of type ErrNotExist is returned. For multi-server storage, care
|
||||
// should be taken to make this load atomic to prevent race conditions
|
||||
// that happen with multiple data loads.
|
||||
LoadSite(domain string) (*SiteData, error)
|
||||
|
||||
// StoreSite persists the given site data for the given domain in
|
||||
// storage. For multi-server storage, care should be taken to make this
|
||||
// call atomic to prevent half-written data on failure of an internal
|
||||
// intermediate storage step. Implementers can trust that at runtime
|
||||
// this function will only be invoked after LockRegister and before
|
||||
// UnlockRegister of the same domain.
|
||||
StoreSite(domain string, data *SiteData) error
|
||||
|
||||
// DeleteSite deletes the site for the given domain from storage.
|
||||
// Multi-server implementations should attempt to make this atomic. If
|
||||
// the site does not exist, an error value of type ErrNotExist is returned.
|
||||
DeleteSite(domain string) error
|
||||
|
||||
// LoadUser obtains user data from storage for the given email and
|
||||
// returns it. If data for the email does not exist, an error value
|
||||
// of type ErrNotExist is returned. Multi-server implementations
|
||||
// should take care to make this operation atomic for all loaded
|
||||
// data items.
|
||||
LoadUser(email string) (*UserData, error)
|
||||
|
||||
// StoreUser persists the given user data for the given email in
|
||||
// storage. Multi-server implementations should take care to make this
|
||||
// operation atomic for all stored data items.
|
||||
StoreUser(email string, data *UserData) error
|
||||
|
||||
// MostRecentUserEmail provides the most recently used email parameter
|
||||
// in StoreUser. The result is an empty string if there are no
|
||||
// persisted users in storage.
|
||||
MostRecentUserEmail() string
|
||||
|
||||
// Locker is necessary because synchronizing certificate maintenance
|
||||
// depends on how storage is implemented.
|
||||
Locker
|
||||
}
|
||||
|
||||
// ErrNotExist is returned by Storage implementations when
|
||||
// a resource is not found. It is similar to os.ErrNotExist
|
||||
// except this is a type, not a variable.
|
||||
type ErrNotExist interface {
|
||||
error
|
||||
}
|
||||
|
||||
// Waiter is a type that can block until a storage lock is released.
|
||||
type Waiter interface {
|
||||
Wait()
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package storagetest
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
)
|
||||
|
||||
// memoryMutex is a mutex used to control access to memoryStoragesByCAURL.
|
||||
var memoryMutex sync.Mutex
|
||||
|
||||
// memoryStoragesByCAURL is a map keyed by a CA URL string with values of
|
||||
// instantiated memory stores. Do not access this directly, it is used by
|
||||
// InMemoryStorageCreator.
|
||||
var memoryStoragesByCAURL = make(map[string]*InMemoryStorage)
|
||||
|
||||
// InMemoryStorageCreator is a caddytls.Storage.StorageCreator to create
|
||||
// InMemoryStorage instances for testing.
|
||||
func InMemoryStorageCreator(caURL *url.URL) (caddytls.Storage, error) {
|
||||
urlStr := caURL.String()
|
||||
memoryMutex.Lock()
|
||||
defer memoryMutex.Unlock()
|
||||
storage := memoryStoragesByCAURL[urlStr]
|
||||
if storage == nil {
|
||||
storage = NewInMemoryStorage()
|
||||
memoryStoragesByCAURL[urlStr] = storage
|
||||
}
|
||||
return storage, nil
|
||||
}
|
||||
|
||||
// InMemoryStorage is a caddytls.Storage implementation for use in testing.
|
||||
// It simply stores information in runtime memory.
|
||||
type InMemoryStorage struct {
|
||||
// Sites are exposed for testing purposes.
|
||||
Sites map[string]*caddytls.SiteData
|
||||
// Users are exposed for testing purposes.
|
||||
Users map[string]*caddytls.UserData
|
||||
// LastUserEmail is exposed for testing purposes.
|
||||
LastUserEmail string
|
||||
}
|
||||
|
||||
// NewInMemoryStorage constructs an InMemoryStorage instance. For use with
|
||||
// caddytls, the InMemoryStorageCreator should be used instead.
|
||||
func NewInMemoryStorage() *InMemoryStorage {
|
||||
return &InMemoryStorage{
|
||||
Sites: make(map[string]*caddytls.SiteData),
|
||||
Users: make(map[string]*caddytls.UserData),
|
||||
}
|
||||
}
|
||||
|
||||
// SiteExists implements caddytls.Storage.SiteExists in memory.
|
||||
func (s *InMemoryStorage) SiteExists(domain string) (bool, error) {
|
||||
_, siteExists := s.Sites[domain]
|
||||
return siteExists, nil
|
||||
}
|
||||
|
||||
// Clear completely clears all values associated with this storage.
|
||||
func (s *InMemoryStorage) Clear() {
|
||||
s.Sites = make(map[string]*caddytls.SiteData)
|
||||
s.Users = make(map[string]*caddytls.UserData)
|
||||
s.LastUserEmail = ""
|
||||
}
|
||||
|
||||
// LoadSite implements caddytls.Storage.LoadSite in memory.
|
||||
func (s *InMemoryStorage) LoadSite(domain string) (*caddytls.SiteData, error) {
|
||||
siteData, ok := s.Sites[domain]
|
||||
if !ok {
|
||||
return nil, caddytls.ErrNotExist(errors.New("not found"))
|
||||
}
|
||||
return siteData, nil
|
||||
}
|
||||
|
||||
func copyBytes(from []byte) []byte {
|
||||
copiedBytes := make([]byte, len(from))
|
||||
copy(copiedBytes, from)
|
||||
return copiedBytes
|
||||
}
|
||||
|
||||
// StoreSite implements caddytls.Storage.StoreSite in memory.
|
||||
func (s *InMemoryStorage) StoreSite(domain string, data *caddytls.SiteData) error {
|
||||
copiedData := new(caddytls.SiteData)
|
||||
copiedData.Cert = copyBytes(data.Cert)
|
||||
copiedData.Key = copyBytes(data.Key)
|
||||
copiedData.Meta = copyBytes(data.Meta)
|
||||
s.Sites[domain] = copiedData
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteSite implements caddytls.Storage.DeleteSite in memory.
|
||||
func (s *InMemoryStorage) DeleteSite(domain string) error {
|
||||
if _, ok := s.Sites[domain]; !ok {
|
||||
return caddytls.ErrNotExist(errors.New("not found"))
|
||||
}
|
||||
delete(s.Sites, domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TryLock implements Storage.TryLock by returning nil values because it
|
||||
// is not a multi-server storage implementation.
|
||||
func (s *InMemoryStorage) TryLock(domain string) (caddytls.Waiter, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Unlock implements Storage.Unlock as a no-op because it is
|
||||
// not a multi-server storage implementation.
|
||||
func (s *InMemoryStorage) Unlock(domain string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadUser implements caddytls.Storage.LoadUser in memory.
|
||||
func (s *InMemoryStorage) LoadUser(email string) (*caddytls.UserData, error) {
|
||||
userData, ok := s.Users[email]
|
||||
if !ok {
|
||||
return nil, caddytls.ErrNotExist(errors.New("not found"))
|
||||
}
|
||||
return userData, nil
|
||||
}
|
||||
|
||||
// StoreUser implements caddytls.Storage.StoreUser in memory.
|
||||
func (s *InMemoryStorage) StoreUser(email string, data *caddytls.UserData) error {
|
||||
copiedData := new(caddytls.UserData)
|
||||
copiedData.Reg = copyBytes(data.Reg)
|
||||
copiedData.Key = copyBytes(data.Key)
|
||||
s.Users[email] = copiedData
|
||||
s.LastUserEmail = email
|
||||
return nil
|
||||
}
|
||||
|
||||
// MostRecentUserEmail implements caddytls.Storage.MostRecentUserEmail in memory.
|
||||
func (s *InMemoryStorage) MostRecentUserEmail() string {
|
||||
return s.LastUserEmail
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package storagetest
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestMemoryStorage(t *testing.T) {
|
||||
storage := NewInMemoryStorage()
|
||||
storageTest := &StorageTest{
|
||||
Storage: storage,
|
||||
PostTest: storage.Clear,
|
||||
}
|
||||
storageTest.Test(t, false)
|
||||
}
|
||||
@@ -1,306 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package storagetest provides utilities to assist in testing caddytls.Storage
|
||||
// implementations.
|
||||
package storagetest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
)
|
||||
|
||||
// StorageTest is a test harness that contains tests to execute all exposed
|
||||
// parts of a Storage implementation.
|
||||
type StorageTest struct {
|
||||
// Storage is the implementation to use during tests. This must be
|
||||
// present.
|
||||
caddytls.Storage
|
||||
|
||||
// PreTest, if present, is called before every test. Any error returned
|
||||
// is returned from the test and the test does not continue.
|
||||
PreTest func() error
|
||||
|
||||
// PostTest, if present, is executed after every test via defer which
|
||||
// means it executes even on failure of the test (but not on failure of
|
||||
// PreTest).
|
||||
PostTest func()
|
||||
|
||||
// AfterUserEmailStore, if present, is invoked during
|
||||
// TestMostRecentUserEmail after each storage just in case anything
|
||||
// needs to be mocked.
|
||||
AfterUserEmailStore func(email string) error
|
||||
}
|
||||
|
||||
// TestFunc holds information about a test.
|
||||
type TestFunc struct {
|
||||
// Name is the friendly name of the test.
|
||||
Name string
|
||||
|
||||
// Fn is the function that is invoked for the test.
|
||||
Fn func() error
|
||||
}
|
||||
|
||||
// runPreTest runs the PreTest function if present.
|
||||
func (s *StorageTest) runPreTest() error {
|
||||
if s.PreTest != nil {
|
||||
return s.PreTest()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// runPostTest runs the PostTest function if present.
|
||||
func (s *StorageTest) runPostTest() {
|
||||
if s.PostTest != nil {
|
||||
s.PostTest()
|
||||
}
|
||||
}
|
||||
|
||||
// AllFuncs returns all test functions that are part of this harness.
|
||||
func (s *StorageTest) AllFuncs() []TestFunc {
|
||||
return []TestFunc{
|
||||
{"TestSiteInfoExists", s.TestSiteExists},
|
||||
{"TestSite", s.TestSite},
|
||||
{"TestUser", s.TestUser},
|
||||
{"TestMostRecentUserEmail", s.TestMostRecentUserEmail},
|
||||
}
|
||||
}
|
||||
|
||||
// Test executes the entire harness using the testing package. Failures are
|
||||
// reported via T.Fatal. If eagerFail is true, the first failure causes all
|
||||
// testing to stop immediately.
|
||||
func (s *StorageTest) Test(t *testing.T, eagerFail bool) {
|
||||
if errs := s.TestAll(eagerFail); len(errs) > 0 {
|
||||
ifaces := make([]interface{}, len(errs))
|
||||
for i, err := range errs {
|
||||
ifaces[i] = err
|
||||
}
|
||||
t.Fatal(ifaces...)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAll executes the entire harness and returns the results as an array of
|
||||
// errors. If eagerFail is true, the first failure causes all testing to stop
|
||||
// immediately.
|
||||
func (s *StorageTest) TestAll(eagerFail bool) (errs []error) {
|
||||
for _, fn := range s.AllFuncs() {
|
||||
if err := fn.Fn(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("%v failed: %v", fn.Name, err))
|
||||
if eagerFail {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var simpleSiteData = &caddytls.SiteData{
|
||||
Cert: []byte("foo"),
|
||||
Key: []byte("bar"),
|
||||
Meta: []byte("baz"),
|
||||
}
|
||||
var simpleSiteDataAlt = &caddytls.SiteData{
|
||||
Cert: []byte("qux"),
|
||||
Key: []byte("quux"),
|
||||
Meta: []byte("corge"),
|
||||
}
|
||||
|
||||
// TestSiteExists tests Storage.SiteExists.
|
||||
func (s *StorageTest) TestSiteExists() error {
|
||||
if err := s.runPreTest(); err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.runPostTest()
|
||||
|
||||
// Should not exist at first
|
||||
siteExists, err := s.SiteExists("example.com")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if siteExists {
|
||||
return errors.New("Site should not exist")
|
||||
}
|
||||
|
||||
// Should exist after we store it
|
||||
if err := s.StoreSite("example.com", simpleSiteData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
siteExists, err = s.SiteExists("example.com")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !siteExists {
|
||||
return errors.New("Expected site to exist")
|
||||
}
|
||||
|
||||
// Site should no longer exist after we delete it
|
||||
if err := s.DeleteSite("example.com"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
siteExists, err = s.SiteExists("example.com")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if siteExists {
|
||||
return errors.New("Site should not exist after delete")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestSite tests Storage.LoadSite, Storage.StoreSite, and Storage.DeleteSite.
|
||||
func (s *StorageTest) TestSite() error {
|
||||
if err := s.runPreTest(); err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.runPostTest()
|
||||
|
||||
// Should be a not-found error at first
|
||||
_, err := s.LoadSite("example.com")
|
||||
if _, ok := err.(caddytls.ErrNotExist); !ok {
|
||||
return fmt.Errorf("Expected caddytls.ErrNotExist from load, got %T: %v", err, err)
|
||||
}
|
||||
|
||||
// Delete should also be a not-found error at first
|
||||
err = s.DeleteSite("example.com")
|
||||
if _, ok := err.(caddytls.ErrNotExist); !ok {
|
||||
return fmt.Errorf("Expected ErrNotExist from delete, got: %v", err)
|
||||
}
|
||||
|
||||
// Should store successfully and then load just fine
|
||||
if err := s.StoreSite("example.com", simpleSiteData); err != nil {
|
||||
return err
|
||||
}
|
||||
if siteData, err := s.LoadSite("example.com"); err != nil {
|
||||
return err
|
||||
} else if !bytes.Equal(siteData.Cert, simpleSiteData.Cert) {
|
||||
return errors.New("Unexpected cert returned after store")
|
||||
} else if !bytes.Equal(siteData.Key, simpleSiteData.Key) {
|
||||
return errors.New("Unexpected key returned after store")
|
||||
} else if !bytes.Equal(siteData.Meta, simpleSiteData.Meta) {
|
||||
return errors.New("Unexpected meta returned after store")
|
||||
}
|
||||
|
||||
// Overwrite should work just fine
|
||||
if err := s.StoreSite("example.com", simpleSiteDataAlt); err != nil {
|
||||
return err
|
||||
}
|
||||
if siteData, err := s.LoadSite("example.com"); err != nil {
|
||||
return err
|
||||
} else if !bytes.Equal(siteData.Cert, simpleSiteDataAlt.Cert) {
|
||||
return errors.New("Unexpected cert returned after overwrite")
|
||||
}
|
||||
|
||||
// It should delete fine and then not be there
|
||||
if err := s.DeleteSite("example.com"); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.LoadSite("example.com")
|
||||
if _, ok := err.(caddytls.ErrNotExist); !ok {
|
||||
return fmt.Errorf("Expected caddytls.ErrNotExist after delete, got %T: %v", err, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var simpleUserData = &caddytls.UserData{
|
||||
Reg: []byte("foo"),
|
||||
Key: []byte("bar"),
|
||||
}
|
||||
var simpleUserDataAlt = &caddytls.UserData{
|
||||
Reg: []byte("baz"),
|
||||
Key: []byte("qux"),
|
||||
}
|
||||
|
||||
// TestUser tests Storage.LoadUser and Storage.StoreUser.
|
||||
func (s *StorageTest) TestUser() error {
|
||||
if err := s.runPreTest(); err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.runPostTest()
|
||||
|
||||
// Should be a not-found error at first
|
||||
_, err := s.LoadUser("foo@example.com")
|
||||
if _, ok := err.(caddytls.ErrNotExist); !ok {
|
||||
return fmt.Errorf("Expected caddytls.ErrNotExist from load, got %T: %v", err, err)
|
||||
}
|
||||
|
||||
// Should store successfully and then load just fine
|
||||
if err := s.StoreUser("foo@example.com", simpleUserData); err != nil {
|
||||
return err
|
||||
}
|
||||
if userData, err := s.LoadUser("foo@example.com"); err != nil {
|
||||
return err
|
||||
} else if !bytes.Equal(userData.Reg, simpleUserData.Reg) {
|
||||
return errors.New("Unexpected reg returned after store")
|
||||
} else if !bytes.Equal(userData.Key, simpleUserData.Key) {
|
||||
return errors.New("Unexpected key returned after store")
|
||||
}
|
||||
|
||||
// Overwrite should work just fine
|
||||
if err := s.StoreUser("foo@example.com", simpleUserDataAlt); err != nil {
|
||||
return err
|
||||
}
|
||||
if userData, err := s.LoadUser("foo@example.com"); err != nil {
|
||||
return err
|
||||
} else if !bytes.Equal(userData.Reg, simpleUserDataAlt.Reg) {
|
||||
return errors.New("Unexpected reg returned after overwrite")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestMostRecentUserEmail tests Storage.MostRecentUserEmail.
|
||||
func (s *StorageTest) TestMostRecentUserEmail() error {
|
||||
if err := s.runPreTest(); err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.runPostTest()
|
||||
|
||||
// Should be empty on first run
|
||||
if e := s.MostRecentUserEmail(); e != "" {
|
||||
return fmt.Errorf("Expected empty most recent user on first run, got: %v", e)
|
||||
}
|
||||
|
||||
// If we store user, then that one should be returned
|
||||
if err := s.StoreUser("foo1@example.com", simpleUserData); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.AfterUserEmailStore != nil {
|
||||
s.AfterUserEmailStore("foo1@example.com")
|
||||
}
|
||||
if e := s.MostRecentUserEmail(); e != "foo1@example.com" {
|
||||
return fmt.Errorf("Unexpected most recent email after first store: %v", e)
|
||||
}
|
||||
|
||||
// If we store another user, then that one should be returned
|
||||
if err := s.StoreUser("foo2@example.com", simpleUserDataAlt); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.AfterUserEmailStore != nil {
|
||||
s.AfterUserEmailStore("foo2@example.com")
|
||||
}
|
||||
if e := s.MostRecentUserEmail(); e != "foo2@example.com" {
|
||||
return fmt.Errorf("Unexpected most recent email after user key: %v", e)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package storagetest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/mholt/caddy/caddytls"
|
||||
)
|
||||
|
||||
// TestFileStorage tests the file storage set with the test harness in this
|
||||
// package.
|
||||
func TestFileStorage(t *testing.T) {
|
||||
emailCounter := 0
|
||||
storageTest := &StorageTest{
|
||||
Storage: &caddytls.FileStorage{Path: "./testdata"}, // nameLocks isn't made here, but it's okay because the tests don't call TryLock or Unlock
|
||||
PostTest: func() { os.RemoveAll("./testdata") },
|
||||
AfterUserEmailStore: func(email string) error {
|
||||
// We need to change the dir mod time to show a
|
||||
// that certain dirs are newer.
|
||||
emailCounter++
|
||||
fp := filepath.Join("./testdata", "users", email)
|
||||
|
||||
// What we will do is subtract 10 days from today and
|
||||
// then add counter * seconds to make the later
|
||||
// counters newer. We accept that this isn't exactly
|
||||
// how the file storage works because it only changes
|
||||
// timestamps on *newly seen* users, but it achieves
|
||||
// the result that the harness expects.
|
||||
chTime := time.Now().AddDate(0, 0, -10).Add(time.Duration(emailCounter) * time.Second)
|
||||
if err := os.Chtimes(fp, chTime, chTime); err != nil {
|
||||
return fmt.Errorf("Unable to change file time for %v: %v", fp, err)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
storageTest.Test(t, false)
|
||||
}
|
||||
+23
-128
@@ -15,7 +15,7 @@
|
||||
// Package caddytls facilitates the management of TLS assets and integrates
|
||||
// Let's Encrypt functionality into Caddy with first-class support for
|
||||
// creating and renewing certificates automatically. It also implements
|
||||
// the tls directive.
|
||||
// the tls directive. It's mostly powered by the CertMagic package.
|
||||
//
|
||||
// This package is meant to be used by Caddy server types. To use the
|
||||
// tls directive, a server type must import this package and call
|
||||
@@ -29,100 +29,11 @@
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/mholt/certmagic"
|
||||
"github.com/xenolf/lego/challenge"
|
||||
)
|
||||
|
||||
// HostQualifies returns true if the hostname alone
|
||||
// appears eligible for automatic HTTPS. For example,
|
||||
// localhost, empty hostname, and IP addresses are
|
||||
// not eligible because we cannot obtain certificates
|
||||
// for those names.
|
||||
func HostQualifies(hostname string) bool {
|
||||
return hostname != "localhost" && // localhost is ineligible
|
||||
|
||||
// hostname must not be empty
|
||||
strings.TrimSpace(hostname) != "" &&
|
||||
|
||||
// must not contain wildcard (*) characters (until CA supports it)
|
||||
!strings.Contains(hostname, "*") &&
|
||||
|
||||
// must not start or end with a dot
|
||||
!strings.HasPrefix(hostname, ".") &&
|
||||
!strings.HasSuffix(hostname, ".") &&
|
||||
|
||||
// cannot be an IP address, see
|
||||
// https://community.letsencrypt.org/t/certificate-for-static-ip/84/2?u=mholt
|
||||
net.ParseIP(hostname) == nil
|
||||
}
|
||||
|
||||
// saveCertResource saves the certificate resource to disk. This
|
||||
// includes the certificate file itself, the private key, and the
|
||||
// metadata file.
|
||||
func saveCertResource(storage Storage, cert acme.CertificateResource) error {
|
||||
// Save cert, private key, and metadata
|
||||
siteData := &SiteData{
|
||||
Cert: cert.Certificate,
|
||||
Key: cert.PrivateKey,
|
||||
}
|
||||
var err error
|
||||
siteData.Meta, err = json.MarshalIndent(&cert, "", "\t")
|
||||
if err == nil {
|
||||
err = storage.StoreSite(cert.Domain, siteData)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Revoke revokes the certificate for host via ACME protocol.
|
||||
// It assumes the certificate was obtained from the
|
||||
// CA at DefaultCAUrl.
|
||||
func Revoke(host string) error {
|
||||
client, err := newACMEClient(new(Config), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return client.Revoke(host)
|
||||
}
|
||||
|
||||
// tlsSNISolver is a type that can solve TLS-SNI challenges using
|
||||
// an existing listener and our custom, in-memory certificate cache.
|
||||
type tlsSNISolver struct {
|
||||
certCache *certificateCache
|
||||
}
|
||||
|
||||
// Present adds the challenge certificate to the cache.
|
||||
func (s tlsSNISolver) Present(domain, token, keyAuth string) error {
|
||||
cert, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
certHash := hashCertificateChain(cert.Certificate)
|
||||
s.certCache.Lock()
|
||||
s.certCache.cache[acmeDomain] = Certificate{
|
||||
Certificate: cert,
|
||||
Names: []string{acmeDomain},
|
||||
Hash: certHash, // perhaps not necesssary
|
||||
}
|
||||
s.certCache.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanUp removes the challenge certificate from the cache.
|
||||
func (s tlsSNISolver) CleanUp(domain, token, keyAuth string) error {
|
||||
_, acmeDomain, err := acme.TLSSNI01ChallengeCert(keyAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.certCache.Lock()
|
||||
delete(s.certCache.cache, acmeDomain)
|
||||
s.certCache.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConfigHolder is any type that has a Config; it presumably is
|
||||
// connected to a hostname and port on which it is serving.
|
||||
type ConfigHolder interface {
|
||||
@@ -144,11 +55,12 @@ func QualifiesForManagedTLS(c ConfigHolder) bool {
|
||||
return false
|
||||
}
|
||||
tlsConfig := c.TLSConfig()
|
||||
if tlsConfig == nil {
|
||||
if tlsConfig == nil || tlsConfig.Manager == nil {
|
||||
return false
|
||||
}
|
||||
onDemand := tlsConfig.Manager.OnDemand != nil
|
||||
|
||||
return (!tlsConfig.Manual || tlsConfig.OnDemand) && // user might provide own cert and key
|
||||
return (!tlsConfig.Manual || onDemand) && // user might provide own cert and key
|
||||
|
||||
// if self-signed, we've already generated one to use
|
||||
!tlsConfig.SelfSigned &&
|
||||
@@ -159,17 +71,30 @@ func QualifiesForManagedTLS(c ConfigHolder) bool {
|
||||
|
||||
// we get can't certs for some kinds of hostnames, but
|
||||
// on-demand TLS allows empty hostnames at startup
|
||||
(HostQualifies(c.Host()) || tlsConfig.OnDemand)
|
||||
(certmagic.HostQualifies(c.Host()) || onDemand)
|
||||
}
|
||||
|
||||
// Revoke revokes the certificate fro host via the ACME protocol.
|
||||
// It assumes the certificate was obtained from certmagic.CA.
|
||||
func Revoke(domainName string) error {
|
||||
return certmagic.NewDefault().RevokeCert(domainName, true)
|
||||
}
|
||||
|
||||
// KnownACMECAs is a list of ACME directory endpoints of
|
||||
// known, public, and trusted ACME-compatible certificate
|
||||
// authorities.
|
||||
var KnownACMECAs = []string{
|
||||
"https://acme-v02.api.letsencrypt.org/directory",
|
||||
}
|
||||
|
||||
// ChallengeProvider defines an own type that should be used in Caddy plugins
|
||||
// over acme.ChallengeProvider. Using acme.ChallengeProvider causes version mismatches
|
||||
// over challenge.Provider. Using challenge.Provider causes version mismatches
|
||||
// with vendored dependencies (see https://github.com/mattfarina/golang-broken-vendor)
|
||||
//
|
||||
// acme.ChallengeProvider is an interface that allows the implementation of custom
|
||||
// challenge.Provider is an interface that allows the implementation of custom
|
||||
// challenge providers. For more details, see:
|
||||
// https://godoc.org/github.com/xenolf/lego/acme#ChallengeProvider
|
||||
type ChallengeProvider acme.ChallengeProvider
|
||||
type ChallengeProvider challenge.Provider
|
||||
|
||||
// DNSProviderConstructor is a function that takes credentials and
|
||||
// returns a type that can solve the ACME DNS challenges.
|
||||
@@ -183,33 +108,3 @@ func RegisterDNSProvider(name string, provider DNSProviderConstructor) {
|
||||
dnsProviders[name] = provider
|
||||
caddy.RegisterPlugin("tls.dns."+name, caddy.Plugin{})
|
||||
}
|
||||
|
||||
var (
|
||||
// DefaultEmail represents the Let's Encrypt account email to use if none provided.
|
||||
DefaultEmail string
|
||||
|
||||
// Agreed indicates whether user has agreed to the Let's Encrypt SA.
|
||||
Agreed bool
|
||||
|
||||
// DefaultCAUrl is the default URL to the CA's ACME directory endpoint.
|
||||
// It's very important to set this unless you set it in every Config.
|
||||
DefaultCAUrl string
|
||||
|
||||
// DefaultKeyType is used as the type of key for new certificates
|
||||
// when no other key type is specified.
|
||||
DefaultKeyType = acme.RSA2048
|
||||
|
||||
// DisableHTTPChallenge will disable all HTTP challenges.
|
||||
DisableHTTPChallenge bool
|
||||
|
||||
// DisableTLSSNIChallenge will disable all TLS-SNI challenges.
|
||||
DisableTLSSNIChallenge bool
|
||||
)
|
||||
|
||||
var storageProviders = make(map[string]StorageConstructor)
|
||||
|
||||
// RegisterStorageProvider registers provider by name for storing tls data
|
||||
func RegisterStorageProvider(name string, provider StorageConstructor) {
|
||||
storageProviders[name] = provider
|
||||
caddy.RegisterPlugin("tls.storage."+name, caddy.Plugin{})
|
||||
}
|
||||
|
||||
+11
-128
@@ -15,46 +15,11 @@
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/xenolf/lego/acme"
|
||||
"github.com/mholt/certmagic"
|
||||
)
|
||||
|
||||
func TestHostQualifies(t *testing.T) {
|
||||
for i, test := range []struct {
|
||||
host string
|
||||
expect bool
|
||||
}{
|
||||
{"example.com", true},
|
||||
{"sub.example.com", true},
|
||||
{"Sub.Example.COM", true},
|
||||
{"127.0.0.1", false},
|
||||
{"127.0.1.5", false},
|
||||
{"69.123.43.94", false},
|
||||
{"::1", false},
|
||||
{"::", false},
|
||||
{"0.0.0.0", false},
|
||||
{"", false},
|
||||
{" ", false},
|
||||
{"*.example.com", false},
|
||||
{".com", false},
|
||||
{"example.com.", false},
|
||||
{"localhost", false},
|
||||
{"local", true},
|
||||
{"devsite", true},
|
||||
{"192.168.1.3", false},
|
||||
{"10.0.2.1", false},
|
||||
{"169.112.53.4", false},
|
||||
} {
|
||||
actual := HostQualifies(test.host)
|
||||
if actual != test.expect {
|
||||
t.Errorf("Test %d: Expected HostQualifies(%s)=%v, but got %v",
|
||||
i, test.host, test.expect, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type holder struct {
|
||||
host, port string
|
||||
cfg *Config
|
||||
@@ -76,14 +41,17 @@ func TestQualifiesForManagedTLS(t *testing.T) {
|
||||
{holder{host: "", cfg: new(Config)}, false},
|
||||
{holder{host: "localhost", cfg: new(Config)}, false},
|
||||
{holder{host: "123.44.3.21", cfg: new(Config)}, false},
|
||||
{holder{host: "example.com", cfg: new(Config)}, true},
|
||||
{holder{host: "*.example.com", cfg: new(Config)}, false},
|
||||
{holder{host: "example.com", cfg: &Config{Manual: true}}, false},
|
||||
{holder{host: "example.com", cfg: &Config{ACMEEmail: "off"}}, false},
|
||||
{holder{host: "example.com", cfg: &Config{ACMEEmail: "foo@bar.com"}}, true},
|
||||
{holder{host: "example.com", cfg: &Config{Manager: &certmagic.Config{}}}, true},
|
||||
{holder{host: "*.example.com", cfg: &Config{Manager: &certmagic.Config{}}}, true},
|
||||
{holder{host: "*.*.example.com", cfg: new(Config)}, false},
|
||||
{holder{host: "*sub.example.com", cfg: new(Config)}, false},
|
||||
{holder{host: "sub.*.example.com", cfg: new(Config)}, false},
|
||||
{holder{host: "example.com", cfg: &Config{Manager: &certmagic.Config{}, Manual: true}}, false},
|
||||
{holder{host: "example.com", cfg: &Config{Manager: &certmagic.Config{}, ACMEEmail: "off"}}, false},
|
||||
{holder{host: "example.com", cfg: &Config{Manager: &certmagic.Config{}, ACMEEmail: "foo@bar.com"}}, true},
|
||||
{holder{host: "example.com", port: "80"}, false},
|
||||
{holder{host: "example.com", port: "1234", cfg: new(Config)}, true},
|
||||
{holder{host: "example.com", port: "443", cfg: new(Config)}, true},
|
||||
{holder{host: "example.com", port: "1234", cfg: &Config{Manager: &certmagic.Config{}}}, true},
|
||||
{holder{host: "example.com", port: "443", cfg: &Config{Manager: &certmagic.Config{}}}, true},
|
||||
{holder{host: "example.com", port: "80"}, false},
|
||||
} {
|
||||
if got, want := QualifiesForManagedTLS(test.cfg), test.expect; got != want {
|
||||
@@ -91,88 +59,3 @@ func TestQualifiesForManagedTLS(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveCertResource(t *testing.T) {
|
||||
storage := &FileStorage{Path: "./le_test_save"}
|
||||
defer func() {
|
||||
err := os.RemoveAll(storage.Path)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not remove temporary storage directory (%s): %v", storage.Path, err)
|
||||
}
|
||||
}()
|
||||
|
||||
domain := "example.com"
|
||||
certContents := "certificate"
|
||||
keyContents := "private key"
|
||||
metaContents := `{
|
||||
"domain": "example.com",
|
||||
"certUrl": "https://example.com/cert",
|
||||
"certStableUrl": "https://example.com/cert/stable"
|
||||
}`
|
||||
|
||||
cert := acme.CertificateResource{
|
||||
Domain: domain,
|
||||
CertURL: "https://example.com/cert",
|
||||
CertStableURL: "https://example.com/cert/stable",
|
||||
PrivateKey: []byte(keyContents),
|
||||
Certificate: []byte(certContents),
|
||||
}
|
||||
|
||||
err := saveCertResource(storage, cert)
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
siteData, err := storage.LoadSite(domain)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error reading site, got: %v", err)
|
||||
}
|
||||
if string(siteData.Cert) != certContents {
|
||||
t.Errorf("Expected certificate file to contain '%s', got '%s'", certContents, string(siteData.Cert))
|
||||
}
|
||||
if string(siteData.Key) != keyContents {
|
||||
t.Errorf("Expected private key file to contain '%s', got '%s'", keyContents, string(siteData.Key))
|
||||
}
|
||||
if string(siteData.Meta) != metaContents {
|
||||
t.Errorf("Expected meta file to contain '%s', got '%s'", metaContents, string(siteData.Meta))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExistingCertAndKey(t *testing.T) {
|
||||
storage := &FileStorage{Path: "./le_test_existing"}
|
||||
defer func() {
|
||||
err := os.RemoveAll(storage.Path)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not remove temporary storage directory (%s): %v", storage.Path, err)
|
||||
}
|
||||
}()
|
||||
|
||||
domain := "example.com"
|
||||
|
||||
siteExists, err := storage.SiteExists(domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not determine whether site exists: %v", err)
|
||||
}
|
||||
|
||||
if siteExists {
|
||||
t.Errorf("Did NOT expect %v to have existing cert or key, but it did", domain)
|
||||
}
|
||||
|
||||
err = saveCertResource(storage, acme.CertificateResource{
|
||||
Domain: domain,
|
||||
PrivateKey: []byte("key"),
|
||||
Certificate: []byte("cert"),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Expected no error, got: %v", err)
|
||||
}
|
||||
|
||||
siteExists, err = storage.SiteExists(domain)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not determine whether site exists: %v", err)
|
||||
}
|
||||
|
||||
if !siteExists {
|
||||
t.Errorf("Expected %v to have existing cert and key, but it did NOT", domain)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,190 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
// User represents a Let's Encrypt user account.
|
||||
type User struct {
|
||||
Email string
|
||||
Registration *acme.RegistrationResource
|
||||
key crypto.PrivateKey
|
||||
}
|
||||
|
||||
// GetEmail gets u's email.
|
||||
func (u User) GetEmail() string {
|
||||
return u.Email
|
||||
}
|
||||
|
||||
// GetRegistration gets u's registration resource.
|
||||
func (u User) GetRegistration() *acme.RegistrationResource {
|
||||
return u.Registration
|
||||
}
|
||||
|
||||
// GetPrivateKey gets u's private key.
|
||||
func (u User) GetPrivateKey() crypto.PrivateKey {
|
||||
return u.key
|
||||
}
|
||||
|
||||
// newUser creates a new User for the given email address
|
||||
// with a new private key. This function does NOT save the
|
||||
// user to disk or register it via ACME. If you want to use
|
||||
// a user account that might already exist, call getUser
|
||||
// instead. It does NOT prompt the user.
|
||||
func newUser(email string) (User, error) {
|
||||
user := User{Email: email}
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
if err != nil {
|
||||
return user, errors.New("error generating private key: " + err.Error())
|
||||
}
|
||||
user.key = privateKey
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// getEmail does everything it can to obtain an email
|
||||
// address from the user within the scope of storage
|
||||
// to use for ACME TLS. If it cannot get an email
|
||||
// address, it returns empty string. (It will warn the
|
||||
// user of the consequences of an empty email.) This
|
||||
// function MAY prompt the user for input. If userPresent
|
||||
// is false, the operator will NOT be prompted and an
|
||||
// empty email may be returned.
|
||||
func getEmail(storage Storage, userPresent bool) string {
|
||||
// First try memory (command line flag or typed by user previously)
|
||||
leEmail := DefaultEmail
|
||||
if leEmail == "" {
|
||||
// Then try to get most recent user email
|
||||
leEmail = storage.MostRecentUserEmail()
|
||||
// Save for next time
|
||||
DefaultEmail = leEmail
|
||||
}
|
||||
if leEmail == "" && userPresent {
|
||||
// Alas, we must bother the user and ask for an email address;
|
||||
// if they proceed they also agree to the SA.
|
||||
reader := bufio.NewReader(stdin)
|
||||
fmt.Println("\nYour sites will be served over HTTPS automatically using Let's Encrypt.")
|
||||
fmt.Println("By continuing, you agree to the Let's Encrypt Subscriber Agreement at:")
|
||||
fmt.Println(" " + saURL) // TODO: Show current SA link
|
||||
fmt.Println("Please enter your email address so you can recover your account if needed.")
|
||||
fmt.Println("You can leave it blank, but you'll lose the ability to recover your account.")
|
||||
fmt.Print("Email address: ")
|
||||
var err error
|
||||
leEmail, err = reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
leEmail = strings.TrimSpace(leEmail)
|
||||
DefaultEmail = leEmail
|
||||
Agreed = true
|
||||
}
|
||||
return strings.ToLower(leEmail)
|
||||
}
|
||||
|
||||
// getUser loads the user with the given email from disk
|
||||
// using the provided storage. If the user does not exist,
|
||||
// it will create a new one, but it does NOT save new
|
||||
// users to the disk or register them via ACME. It does
|
||||
// NOT prompt the user.
|
||||
func getUser(storage Storage, email string) (User, error) {
|
||||
var user User
|
||||
|
||||
// open user reg
|
||||
userData, err := storage.LoadUser(email)
|
||||
if err != nil {
|
||||
if _, ok := err.(ErrNotExist); ok {
|
||||
// create a new user
|
||||
return newUser(email)
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
|
||||
// load user information
|
||||
err = json.Unmarshal(userData.Reg, &user)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
// load their private key
|
||||
user.key, err = loadPrivateKey(userData.Key)
|
||||
return user, err
|
||||
}
|
||||
|
||||
// saveUser persists a user's key and account registration
|
||||
// to the file system. It does NOT register the user via ACME
|
||||
// or prompt the user. You must also pass in the storage
|
||||
// wherein the user should be saved. It should be the storage
|
||||
// for the CA with which user has an account.
|
||||
func saveUser(storage Storage, user User) error {
|
||||
// Save the private key and registration
|
||||
userData := new(UserData)
|
||||
var err error
|
||||
userData.Key, err = savePrivateKey(user.key)
|
||||
if err == nil {
|
||||
userData.Reg, err = json.MarshalIndent(&user, "", "\t")
|
||||
}
|
||||
if err == nil {
|
||||
err = storage.StoreUser(user.Email, userData)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// promptUserAgreement prompts the user to agree to the agreement
|
||||
// at agreementURL via stdin. If the agreement has changed, then pass
|
||||
// true as the second argument. If this is the user's first time
|
||||
// agreeing, pass false. It returns whether the user agreed or not.
|
||||
func promptUserAgreement(agreementURL string, changed bool) bool {
|
||||
if changed {
|
||||
fmt.Printf("The Let's Encrypt Subscriber Agreement has changed:\n %s\n", agreementURL)
|
||||
fmt.Print("Do you agree to the new terms? (y/n): ")
|
||||
} else {
|
||||
fmt.Printf("To continue, you must agree to the Let's Encrypt Subscriber Agreement:\n %s\n", agreementURL)
|
||||
fmt.Print("Do you agree to the terms? (y/n): ")
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(stdin)
|
||||
answer, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
answer = strings.ToLower(strings.TrimSpace(answer))
|
||||
|
||||
return answer == "y" || answer == "yes"
|
||||
}
|
||||
|
||||
// stdin is used to read the user's input if prompted;
|
||||
// this is changed by tests during tests.
|
||||
var stdin = io.ReadWriter(os.Stdin)
|
||||
|
||||
// The name of the folder for accounts where the email
|
||||
// address was not provided; default 'username' if you will.
|
||||
const emptyEmail = "default"
|
||||
|
||||
// TODO: After Boulder implements the 'meta' field of the directory,
|
||||
// we can get this link dynamically.
|
||||
const saURL = "https://acme-v01.api.letsencrypt.org/terms"
|
||||
@@ -1,202 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package caddytls
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"os"
|
||||
|
||||
"github.com/xenolf/lego/acme"
|
||||
)
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
defer testStorage.clean()
|
||||
|
||||
privateKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not generate test private key: %v", err)
|
||||
}
|
||||
u := User{
|
||||
Email: "me@mine.com",
|
||||
Registration: new(acme.RegistrationResource),
|
||||
key: privateKey,
|
||||
}
|
||||
|
||||
if expected, actual := "me@mine.com", u.GetEmail(); actual != expected {
|
||||
t.Errorf("Expected email '%s' but got '%s'", expected, actual)
|
||||
}
|
||||
if u.GetRegistration() == nil {
|
||||
t.Error("Expected a registration resource, but got nil")
|
||||
}
|
||||
if expected, actual := privateKey, u.GetPrivateKey(); actual != expected {
|
||||
t.Errorf("Expected the private key at address %p but got one at %p instead ", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewUser(t *testing.T) {
|
||||
email := "me@foobar.com"
|
||||
user, err := newUser(email)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating user: %v", err)
|
||||
}
|
||||
if user.key == nil {
|
||||
t.Error("Private key is nil")
|
||||
}
|
||||
if user.Email != email {
|
||||
t.Errorf("Expected email to be %s, but was %s", email, user.Email)
|
||||
}
|
||||
if user.Registration != nil {
|
||||
t.Error("New user already has a registration resource; it shouldn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveUser(t *testing.T) {
|
||||
defer testStorage.clean()
|
||||
|
||||
email := "me@foobar.com"
|
||||
user, err := newUser(email)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating user: %v", err)
|
||||
}
|
||||
|
||||
err = saveUser(testStorage, user)
|
||||
if err != nil {
|
||||
t.Fatalf("Error saving user: %v", err)
|
||||
}
|
||||
_, err = testStorage.LoadUser(email)
|
||||
if err != nil {
|
||||
t.Errorf("Cannot access user data, error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserDoesNotAlreadyExist(t *testing.T) {
|
||||
defer testStorage.clean()
|
||||
|
||||
user, err := getUser(testStorage, "user_does_not_exist@foobar.com")
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting user: %v", err)
|
||||
}
|
||||
|
||||
if user.key == nil {
|
||||
t.Error("Expected user to have a private key, but it was nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserAlreadyExists(t *testing.T) {
|
||||
defer testStorage.clean()
|
||||
|
||||
email := "me@foobar.com"
|
||||
|
||||
// Set up test
|
||||
user, err := newUser(email)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating user: %v", err)
|
||||
}
|
||||
err = saveUser(testStorage, user)
|
||||
if err != nil {
|
||||
t.Fatalf("Error saving user: %v", err)
|
||||
}
|
||||
|
||||
// Expect to load user from disk
|
||||
user2, err := getUser(testStorage, email)
|
||||
if err != nil {
|
||||
t.Fatalf("Error getting user: %v", err)
|
||||
}
|
||||
|
||||
// Assert keys are the same
|
||||
if !PrivateKeysSame(user.key, user2.key) {
|
||||
t.Error("Expected private key to be the same after loading, but it wasn't")
|
||||
}
|
||||
|
||||
// Assert emails are the same
|
||||
if user.Email != user2.Email {
|
||||
t.Errorf("Expected emails to be equal, but was '%s' before and '%s' after loading", user.Email, user2.Email)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetEmail(t *testing.T) {
|
||||
storageBasePath = testStorage.Path // to contain calls that create a new Storage...
|
||||
|
||||
// let's not clutter up the output
|
||||
origStdout := os.Stdout
|
||||
os.Stdout = nil
|
||||
defer func() { os.Stdout = origStdout }()
|
||||
|
||||
defer testStorage.clean()
|
||||
DefaultEmail = "test2@foo.com"
|
||||
|
||||
// Test1: Use default email from flag (or user previously typing it)
|
||||
actual := getEmail(testStorage, true)
|
||||
if actual != DefaultEmail {
|
||||
t.Errorf("Did not get correct email from memory; expected '%s' but got '%s'", DefaultEmail, actual)
|
||||
}
|
||||
|
||||
// Test2: Get input from user
|
||||
DefaultEmail = ""
|
||||
stdin = new(bytes.Buffer)
|
||||
_, err := io.Copy(stdin, strings.NewReader("test3@foo.com\n"))
|
||||
if err != nil {
|
||||
t.Fatalf("Could not simulate user input, error: %v", err)
|
||||
}
|
||||
actual = getEmail(testStorage, true)
|
||||
if actual != "test3@foo.com" {
|
||||
t.Errorf("Did not get correct email from user input prompt; expected '%s' but got '%s'", "test3@foo.com", actual)
|
||||
}
|
||||
|
||||
// Test3: Get most recent email from before
|
||||
DefaultEmail = ""
|
||||
for i, eml := range []string{
|
||||
"TEST4-3@foo.com", // test case insensitivity
|
||||
"test4-2@foo.com",
|
||||
"test4-1@foo.com",
|
||||
} {
|
||||
u, err := newUser(eml)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating user %d: %v", i, err)
|
||||
}
|
||||
err = saveUser(testStorage, u)
|
||||
if err != nil {
|
||||
t.Fatalf("Error saving user %d: %v", i, err)
|
||||
}
|
||||
|
||||
// Change modified time so they're all different and the test becomes more deterministic
|
||||
f, err := os.Stat(testStorage.user(eml))
|
||||
if err != nil {
|
||||
t.Fatalf("Could not access user folder for '%s': %v", eml, err)
|
||||
}
|
||||
chTime := f.ModTime().Add(-(time.Duration(i) * time.Hour)) // 1 second isn't always enough space!
|
||||
if err := os.Chtimes(testStorage.user(eml), chTime, chTime); err != nil {
|
||||
t.Fatalf("Could not change user folder mod time for '%s': %v", eml, err)
|
||||
}
|
||||
}
|
||||
actual = getEmail(testStorage, true)
|
||||
if actual != "test4-3@foo.com" {
|
||||
t.Errorf("Did not get correct email from storage; expected '%s' but got '%s'", "test4-3@foo.com", actual)
|
||||
}
|
||||
}
|
||||
|
||||
var testStorage = &FileStorage{Path: "./testdata"}
|
||||
|
||||
func (s *FileStorage) clean() error {
|
||||
return os.RemoveAll(s.Path)
|
||||
}
|
||||
+11
-5
@@ -71,31 +71,37 @@ func (c *Controller) ServerType() string {
|
||||
// OnFirstStartup adds fn to the list of callback functions to execute
|
||||
// when the server is about to be started NOT as part of a restart.
|
||||
func (c *Controller) OnFirstStartup(fn func() error) {
|
||||
c.instance.onFirstStartup = append(c.instance.onFirstStartup, fn)
|
||||
c.instance.OnFirstStartup = append(c.instance.OnFirstStartup, fn)
|
||||
}
|
||||
|
||||
// OnStartup adds fn to the list of callback functions to execute
|
||||
// when the server is about to be started (including restarts).
|
||||
func (c *Controller) OnStartup(fn func() error) {
|
||||
c.instance.onStartup = append(c.instance.onStartup, fn)
|
||||
c.instance.OnStartup = append(c.instance.OnStartup, fn)
|
||||
}
|
||||
|
||||
// OnRestart adds fn to the list of callback functions to execute
|
||||
// when the server is about to be restarted.
|
||||
func (c *Controller) OnRestart(fn func() error) {
|
||||
c.instance.onRestart = append(c.instance.onRestart, fn)
|
||||
c.instance.OnRestart = append(c.instance.OnRestart, fn)
|
||||
}
|
||||
|
||||
// OnRestartFailed adds fn to the list of callback functions to execute
|
||||
// if the server failed to restart.
|
||||
func (c *Controller) OnRestartFailed(fn func() error) {
|
||||
c.instance.OnRestartFailed = append(c.instance.OnRestartFailed, fn)
|
||||
}
|
||||
|
||||
// OnShutdown adds fn to the list of callback functions to execute
|
||||
// when the server is about to be shut down (including restarts).
|
||||
func (c *Controller) OnShutdown(fn func() error) {
|
||||
c.instance.onShutdown = append(c.instance.onShutdown, fn)
|
||||
c.instance.OnShutdown = append(c.instance.OnShutdown, fn)
|
||||
}
|
||||
|
||||
// OnFinalShutdown adds fn to the list of callback functions to execute
|
||||
// when the server is about to be shut down NOT as part of a restart.
|
||||
func (c *Controller) OnFinalShutdown(fn func() error) {
|
||||
c.instance.onFinalShutdown = append(c.instance.onFinalShutdown, fn)
|
||||
c.instance.OnFinalShutdown = append(c.instance.OnFinalShutdown, fn)
|
||||
}
|
||||
|
||||
// Context gets the context associated with the instance associated with c.
|
||||
|
||||
Vendored
+536
-465
File diff suppressed because it is too large
Load Diff
Vendored
+539
-539
File diff suppressed because it is too large
Load Diff
Vendored
+40
-40
@@ -1,40 +1,40 @@
|
||||
CADDY 0.10.11
|
||||
|
||||
Website
|
||||
https://caddyserver.com
|
||||
|
||||
Community Forum
|
||||
https://caddy.community
|
||||
|
||||
Twitter
|
||||
@caddyserver
|
||||
|
||||
Source Code
|
||||
https://github.com/mholt/caddy
|
||||
https://github.com/caddyserver
|
||||
|
||||
|
||||
For instructions on using Caddy, please see the docs on the
|
||||
website. For a list of what's new in this version, see
|
||||
CHANGES.txt.
|
||||
|
||||
For a good time, follow @mholt6 on Twitter.
|
||||
|
||||
Want to get involved with Caddy's development? We love to have
|
||||
contributions! Please file an issue on GitHub to discuss a
|
||||
change or fix you'd like to make, then submit a pull request
|
||||
and we'll review it! Your contributions will reach millions
|
||||
of people who connect to sites served by Caddy.
|
||||
|
||||
Extend Caddy by developing a plugin for it! Instructions on
|
||||
the project wiki: https://github.com/mholt/caddy/wiki
|
||||
|
||||
And thanks - you're awesome!
|
||||
|
||||
If you think Caddy is awesome too, consider sponsoring it:
|
||||
https://caddyserver.com/sponsor - and help keep Caddy free
|
||||
for personal use.
|
||||
|
||||
|
||||
---
|
||||
(c) 2015-2018 Light Code Labs, LLC
|
||||
CADDY 0.11.2
|
||||
|
||||
Website
|
||||
https://caddyserver.com
|
||||
|
||||
Community Forum
|
||||
https://caddy.community
|
||||
|
||||
Twitter
|
||||
@caddyserver
|
||||
|
||||
Source Code
|
||||
https://github.com/mholt/caddy
|
||||
https://github.com/caddyserver
|
||||
|
||||
|
||||
For instructions on using Caddy, please see the docs on the
|
||||
website. For a list of what's new in this version, see
|
||||
CHANGES.txt.
|
||||
|
||||
For a good time, follow @mholt6 on Twitter.
|
||||
|
||||
Want to get involved with Caddy's development? We love to have
|
||||
contributions! Please file an issue on GitHub to discuss a
|
||||
change or fix you'd like to make, then submit a pull request
|
||||
and we'll review it! Your contributions will reach millions
|
||||
of people who connect to sites served by Caddy.
|
||||
|
||||
Extend Caddy by developing a plugin for it! Instructions on
|
||||
the project wiki: https://github.com/mholt/caddy/wiki
|
||||
|
||||
And thanks - you're awesome!
|
||||
|
||||
If you think Caddy is awesome too, consider sponsoring it:
|
||||
https://caddyserver.com/sponsor - and help keep Caddy free
|
||||
for personal use.
|
||||
|
||||
|
||||
---
|
||||
(c) 2015-2019 Light Code Labs, LLC
|
||||
|
||||
Vendored
+3
-3
@@ -44,7 +44,7 @@ sudo useradd \
|
||||
--system --uid 33 www-data
|
||||
|
||||
sudo mkdir /etc/caddy
|
||||
sudo chown -R root:www-data /etc/caddy
|
||||
sudo chown -R root:root /etc/caddy
|
||||
sudo mkdir /etc/ssl/caddy
|
||||
sudo chown -R root:www-data /etc/ssl/caddy
|
||||
sudo chmod 0770 /etc/ssl/caddy
|
||||
@@ -55,8 +55,8 @@ and give it appropriate ownership and permissions:
|
||||
|
||||
```bash
|
||||
sudo cp /path/to/Caddyfile /etc/caddy/
|
||||
sudo chown www-data:www-data /etc/caddy/Caddyfile
|
||||
sudo chmod 444 /etc/caddy/Caddyfile
|
||||
sudo chown root:root /etc/caddy/Caddyfile
|
||||
sudo chmod 644 /etc/caddy/Caddyfile
|
||||
```
|
||||
|
||||
Create the home directory for the server and give it appropriate ownership
|
||||
|
||||
+102
-18
@@ -22,6 +22,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/mholt/caddy/caddyfile"
|
||||
"github.com/mholt/certmagic"
|
||||
)
|
||||
|
||||
// These are all the registered plugins.
|
||||
@@ -39,7 +40,7 @@ var (
|
||||
|
||||
// eventHooks is a map of hook name to Hook. All hooks plugins
|
||||
// must have a name.
|
||||
eventHooks = sync.Map{}
|
||||
eventHooks = &sync.Map{}
|
||||
|
||||
// parsingCallbacks maps server type to map of directive
|
||||
// to list of callback functions. These aren't really
|
||||
@@ -54,32 +55,70 @@ var (
|
||||
|
||||
// DescribePlugins returns a string describing the registered plugins.
|
||||
func DescribePlugins() string {
|
||||
pl := ListPlugins()
|
||||
|
||||
str := "Server types:\n"
|
||||
for name := range serverTypes {
|
||||
for _, name := range pl["server_types"] {
|
||||
str += " " + name + "\n"
|
||||
}
|
||||
|
||||
// List the loaders in registration order
|
||||
str += "\nCaddyfile loaders:\n"
|
||||
for _, name := range pl["caddyfile_loaders"] {
|
||||
str += " " + name + "\n"
|
||||
}
|
||||
|
||||
if len(pl["event_hooks"]) > 0 {
|
||||
str += "\nEvent hook plugins:\n"
|
||||
for _, name := range pl["event_hooks"] {
|
||||
str += " hook." + name + "\n"
|
||||
}
|
||||
}
|
||||
|
||||
if len(pl["clustering"]) > 0 {
|
||||
str += "\nClustering plugins:\n"
|
||||
for _, name := range pl["clustering"] {
|
||||
str += " " + name + "\n"
|
||||
}
|
||||
}
|
||||
|
||||
str += "\nOther plugins:\n"
|
||||
for _, name := range pl["others"] {
|
||||
str += " " + name + "\n"
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// ListPlugins makes a list of the registered plugins,
|
||||
// keyed by plugin type.
|
||||
func ListPlugins() map[string][]string {
|
||||
p := make(map[string][]string)
|
||||
|
||||
// server type plugins
|
||||
for name := range serverTypes {
|
||||
p["server_types"] = append(p["server_types"], name)
|
||||
}
|
||||
|
||||
// caddyfile loaders in registration order
|
||||
for _, loader := range caddyfileLoaders {
|
||||
str += " " + loader.name + "\n"
|
||||
p["caddyfile_loaders"] = append(p["caddyfile_loaders"], loader.name)
|
||||
}
|
||||
if defaultCaddyfileLoader.name != "" {
|
||||
str += " " + defaultCaddyfileLoader.name + "\n"
|
||||
p["caddyfile_loaders"] = append(p["caddyfile_loaders"], defaultCaddyfileLoader.name)
|
||||
}
|
||||
|
||||
// cluster plugins in registration order
|
||||
for name := range clusterProviders {
|
||||
p["clustering"] = append(p["clsutering"], name)
|
||||
}
|
||||
|
||||
// List the event hook plugins
|
||||
hooks := ""
|
||||
eventHooks.Range(func(k, _ interface{}) bool {
|
||||
hooks += " hook." + k.(string) + "\n"
|
||||
p["event_hooks"] = append(p["event_hooks"], k.(string))
|
||||
return true
|
||||
})
|
||||
if hooks != "" {
|
||||
str += "\nEvent hook plugins:\n"
|
||||
str += hooks
|
||||
}
|
||||
|
||||
// Let's alphabetize the rest of these...
|
||||
// alphabetize the rest of the plugins
|
||||
var others []string
|
||||
for stype, stypePlugins := range plugins {
|
||||
for name := range stypePlugins {
|
||||
@@ -93,12 +132,11 @@ func DescribePlugins() string {
|
||||
}
|
||||
|
||||
sort.Strings(others)
|
||||
str += "\nOther plugins:\n"
|
||||
for _, name := range others {
|
||||
str += " " + name + "\n"
|
||||
p["others"] = append(p["others"], name)
|
||||
}
|
||||
|
||||
return str
|
||||
return p
|
||||
}
|
||||
|
||||
// ValidDirectives returns the list of all directives that are
|
||||
@@ -238,9 +276,10 @@ type EventName string
|
||||
// Define names for the various events
|
||||
const (
|
||||
StartupEvent EventName = "startup"
|
||||
ShutdownEvent EventName = "shutdown"
|
||||
CertRenewEvent EventName = "certrenew"
|
||||
InstanceStartupEvent EventName = "instancestartup"
|
||||
ShutdownEvent = "shutdown"
|
||||
CertRenewEvent = "certrenew"
|
||||
InstanceStartupEvent = "instancestartup"
|
||||
InstanceRestartEvent = "instancerestart"
|
||||
)
|
||||
|
||||
// EventHook is a type which holds information about a startup hook plugin.
|
||||
@@ -271,6 +310,36 @@ func EmitEvent(event EventName, info interface{}) {
|
||||
})
|
||||
}
|
||||
|
||||
// cloneEventHooks return a clone of the event hooks *sync.Map
|
||||
func cloneEventHooks() *sync.Map {
|
||||
c := &sync.Map{}
|
||||
eventHooks.Range(func(k, v interface{}) bool {
|
||||
c.Store(k, v)
|
||||
return true
|
||||
})
|
||||
return c
|
||||
}
|
||||
|
||||
// purgeEventHooks purges all event hooks from the map
|
||||
func purgeEventHooks() {
|
||||
eventHooks.Range(func(k, _ interface{}) bool {
|
||||
eventHooks.Delete(k)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// restoreEventHooks restores eventHooks with a provided *sync.Map
|
||||
func restoreEventHooks(m *sync.Map) {
|
||||
// Purge old event hooks
|
||||
purgeEventHooks()
|
||||
|
||||
// Restore event hooks
|
||||
m.Range(func(k, v interface{}) bool {
|
||||
eventHooks.Store(k, v)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// ParsingCallback is a function that is called after
|
||||
// a directive's setup functions have been executed
|
||||
// for all the server blocks.
|
||||
@@ -387,6 +456,21 @@ func loadCaddyfileInput(serverType string) (Input, error) {
|
||||
return caddyfileToUse, nil
|
||||
}
|
||||
|
||||
// ClusterPluginConstructor is a function type that is used to
|
||||
// instantiate a new implementation of both certmagic.Storage
|
||||
// and certmagic.Locker, which are required for successful
|
||||
// use in cluster environments.
|
||||
type ClusterPluginConstructor func() (certmagic.Storage, error)
|
||||
|
||||
// clusterProviders is the list of storage providers
|
||||
var clusterProviders = make(map[string]ClusterPluginConstructor)
|
||||
|
||||
// RegisterClusterPlugin registers provider by name for facilitating
|
||||
// cluster-wide operations like storage and synchronization.
|
||||
func RegisterClusterPlugin(name string, provider ClusterPluginConstructor) {
|
||||
clusterProviders[name] = provider
|
||||
}
|
||||
|
||||
// OnProcessExit is a list of functions to run when the process
|
||||
// exits -- they are ONLY for cleanup and should not block,
|
||||
// return errors, or do anything fancy. They will be run with
|
||||
|
||||
+1
-1
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// +build windows plan9 nacl
|
||||
// +build windows plan9 nacl js
|
||||
|
||||
package caddy
|
||||
|
||||
|
||||
+2
-2
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// +build !windows,!plan9,!nacl
|
||||
// +build !windows,!plan9,!nacl,!js
|
||||
|
||||
package caddy
|
||||
|
||||
@@ -31,7 +31,7 @@ func checkFdlimit() {
|
||||
err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, rlimit)
|
||||
if err == nil && rlimit.Cur < min {
|
||||
fmt.Printf("WARNING: File descriptor limit %d is too low for production servers. "+
|
||||
"At least %d is recommended. Fix with \"ulimit -n %d\".\n", rlimit.Cur, min, min)
|
||||
"At least %d is recommended. Fix with `ulimit -n %d`.\n", rlimit.Cur, min, min)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -19,6 +19,8 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
)
|
||||
|
||||
// TrapSignals create signal handlers for all applicable signals for this
|
||||
@@ -52,6 +54,9 @@ func trapSignalsCrossPlatform() {
|
||||
|
||||
log.Println("[INFO] SIGINT: Shutting down")
|
||||
|
||||
telemetry.AppendUnique("sigtrap", "SIGINT")
|
||||
go telemetry.StopEmitting() // not guaranteed to finish in time; that's OK (just don't block!)
|
||||
|
||||
// important cleanup actions before shutdown callbacks
|
||||
for _, f := range OnProcessExit {
|
||||
f()
|
||||
|
||||
+1
-1
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// +build windows plan9 nacl
|
||||
// +build windows plan9 nacl js
|
||||
|
||||
package caddy
|
||||
|
||||
|
||||
+19
-1
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// +build !windows,!plan9,!nacl
|
||||
// +build !windows,!plan9,!nacl,!js
|
||||
|
||||
package caddy
|
||||
|
||||
@@ -21,6 +21,8 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/mholt/caddy/telemetry"
|
||||
)
|
||||
|
||||
// trapSignalsPosix captures POSIX-only signals.
|
||||
@@ -49,10 +51,15 @@ func trapSignalsPosix() {
|
||||
log.Printf("[ERROR] SIGTERM stop: %v", err)
|
||||
exitCode = 3
|
||||
}
|
||||
|
||||
telemetry.AppendUnique("sigtrap", "SIGTERM")
|
||||
go telemetry.StopEmitting() // won't finish in time, but that's OK - just don't block
|
||||
|
||||
os.Exit(exitCode)
|
||||
|
||||
case syscall.SIGUSR1:
|
||||
log.Println("[INFO] SIGUSR1: Reloading")
|
||||
go telemetry.AppendUnique("sigtrap", "SIGUSR1")
|
||||
|
||||
// Start with the existing Caddyfile
|
||||
caddyfileToUse, inst, err := getCurrentCaddyfile()
|
||||
@@ -76,20 +83,31 @@ func trapSignalsPosix() {
|
||||
caddyfileToUse = newCaddyfile
|
||||
}
|
||||
|
||||
// Backup old event hooks
|
||||
oldEventHooks := cloneEventHooks()
|
||||
|
||||
// Purge the old event hooks
|
||||
purgeEventHooks()
|
||||
|
||||
// Kick off the restart; our work is done
|
||||
EmitEvent(InstanceRestartEvent, nil)
|
||||
_, err = inst.Restart(caddyfileToUse)
|
||||
if err != nil {
|
||||
restoreEventHooks(oldEventHooks)
|
||||
|
||||
log.Printf("[ERROR] SIGUSR1: %v", err)
|
||||
}
|
||||
|
||||
case syscall.SIGUSR2:
|
||||
log.Println("[INFO] SIGUSR2: Upgrading")
|
||||
go telemetry.AppendUnique("sigtrap", "SIGUSR2")
|
||||
if err := Upgrade(); err != nil {
|
||||
log.Printf("[ERROR] SIGUSR2: upgrading: %v", err)
|
||||
}
|
||||
|
||||
case syscall.SIGHUP:
|
||||
// ignore; this signal is sometimes sent outside of the user's control
|
||||
go telemetry.AppendUnique("sigtrap", "SIGHUP")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package startupshutdown
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/mholt/caddy"
|
||||
"github.com/mholt/caddy/onevent/hook"
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterPlugin("startup", caddy.Plugin{Action: Startup})
|
||||
caddy.RegisterPlugin("shutdown", caddy.Plugin{Action: Shutdown})
|
||||
}
|
||||
|
||||
// Startup (an alias for 'on startup') registers a startup callback to execute during server start.
|
||||
func Startup(c *caddy.Controller) error {
|
||||
config, err := onParse(c, caddy.InstanceStartupEvent)
|
||||
if err != nil {
|
||||
return c.ArgErr()
|
||||
}
|
||||
|
||||
// Register Event Hooks.
|
||||
c.OncePerServerBlock(func() error {
|
||||
for _, cfg := range config {
|
||||
caddy.RegisterEventHook("on-"+cfg.ID, cfg.Hook)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
fmt.Println("NOTICE: Startup directive will be removed in a later version. Please migrate to 'on startup'")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown (an alias for 'on shutdown') registers a shutdown callback to execute during server start.
|
||||
func Shutdown(c *caddy.Controller) error {
|
||||
config, err := onParse(c, caddy.ShutdownEvent)
|
||||
if err != nil {
|
||||
return c.ArgErr()
|
||||
}
|
||||
|
||||
// Register Event Hooks.
|
||||
for _, cfg := range config {
|
||||
caddy.RegisterEventHook("on-"+cfg.ID, cfg.Hook)
|
||||
}
|
||||
|
||||
fmt.Println("NOTICE: Shutdown directive will be removed in a later version. Please migrate to 'on shutdown'")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func onParse(c *caddy.Controller, event caddy.EventName) ([]*hook.Config, error) {
|
||||
var config []*hook.Config
|
||||
|
||||
for c.Next() {
|
||||
cfg := new(hook.Config)
|
||||
|
||||
args := c.RemainingArgs()
|
||||
if len(args) == 0 {
|
||||
return config, c.ArgErr()
|
||||
}
|
||||
|
||||
// Configure Event.
|
||||
cfg.Event = event
|
||||
|
||||
// Assign an unique ID.
|
||||
cfg.ID = uuid.New().String()
|
||||
|
||||
// Extract command and arguments.
|
||||
command, args, err := caddy.SplitCommandAndArgs(strings.Join(args, " "))
|
||||
if err != nil {
|
||||
return config, c.Err(err.Error())
|
||||
}
|
||||
|
||||
cfg.Command = command
|
||||
cfg.Args = args
|
||||
|
||||
config = append(config, cfg)
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
@@ -1,69 +0,0 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package startupshutdown
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/mholt/caddy"
|
||||
)
|
||||
|
||||
func TestStartup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
shouldErr bool
|
||||
}{
|
||||
{name: "noInput", input: "startup", shouldErr: true},
|
||||
{name: "startup", input: "startup cmd arg", shouldErr: false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
c := caddy.NewTestController("", test.input)
|
||||
|
||||
err := Startup(c)
|
||||
if err == nil && test.shouldErr {
|
||||
t.Error("Test didn't error, but it should have")
|
||||
} else if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test errored, but it shouldn't have; got '%v'", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShutdown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
shouldErr bool
|
||||
}{
|
||||
{name: "noInput", input: "shutdown", shouldErr: true},
|
||||
{name: "shutdown", input: "shutdown cmd arg", shouldErr: false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
c := caddy.NewTestController("", test.input)
|
||||
|
||||
err := Shutdown(c)
|
||||
if err == nil && test.shouldErr {
|
||||
t.Error("Test didn't error, but it should have")
|
||||
} else if err != nil && !test.shouldErr {
|
||||
t.Errorf("Test errored, but it shouldn't have; got '%v'", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,307 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// Init initializes this package so that it may
|
||||
// be used. Do not call this function more than
|
||||
// once. Init panics if it is called more than
|
||||
// once or if the UUID value is empty. Once this
|
||||
// function is called, the rest of the package
|
||||
// may safely be used. If this function is not
|
||||
// called, the collector functions may still be
|
||||
// invoked, but they will be no-ops.
|
||||
//
|
||||
// Any metrics keys that are passed in the second
|
||||
// argument will be permanently disabled for the
|
||||
// lifetime of the process.
|
||||
func Init(instanceID uuid.UUID, disabledMetricsKeys []string) {
|
||||
if enabled {
|
||||
panic("already initialized")
|
||||
}
|
||||
if str := instanceID.String(); str == "" ||
|
||||
str == "00000000-0000-0000-0000-000000000000" {
|
||||
panic("empty UUID")
|
||||
}
|
||||
instanceUUID = instanceID
|
||||
disabledMetricsMu.Lock()
|
||||
for _, key := range disabledMetricsKeys {
|
||||
disabledMetrics[strings.TrimSpace(key)] = false
|
||||
}
|
||||
disabledMetricsMu.Unlock()
|
||||
enabled = true
|
||||
}
|
||||
|
||||
// StartEmitting sends the current payload and begins the
|
||||
// transmission cycle for updates. This is the first
|
||||
// update sent, and future ones will be sent until
|
||||
// StopEmitting is called.
|
||||
//
|
||||
// This function is non-blocking (it spawns a new goroutine).
|
||||
//
|
||||
// This function panics if it was called more than once.
|
||||
// It is a no-op if this package was not initialized.
|
||||
func StartEmitting() {
|
||||
if !enabled {
|
||||
return
|
||||
}
|
||||
updateTimerMu.Lock()
|
||||
if updateTimer != nil {
|
||||
updateTimerMu.Unlock()
|
||||
panic("updates already started")
|
||||
}
|
||||
updateTimerMu.Unlock()
|
||||
updateMu.Lock()
|
||||
if updating {
|
||||
updateMu.Unlock()
|
||||
panic("update already in progress")
|
||||
}
|
||||
updateMu.Unlock()
|
||||
go logEmit(false)
|
||||
}
|
||||
|
||||
// StopEmitting sends the current payload and terminates
|
||||
// the update cycle. No more updates will be sent.
|
||||
//
|
||||
// It is a no-op if the package was never initialized
|
||||
// or if emitting was never started.
|
||||
//
|
||||
// NOTE: This function is blocking. Run in a goroutine if
|
||||
// you want to guarantee no blocking at critical times
|
||||
// like exiting the program.
|
||||
func StopEmitting() {
|
||||
if !enabled {
|
||||
return
|
||||
}
|
||||
updateTimerMu.Lock()
|
||||
if updateTimer == nil {
|
||||
updateTimerMu.Unlock()
|
||||
return
|
||||
}
|
||||
updateTimerMu.Unlock()
|
||||
logEmit(true) // likely too early; may take minutes to return
|
||||
}
|
||||
|
||||
// Reset empties the current payload buffer.
|
||||
func Reset() {
|
||||
resetBuffer()
|
||||
}
|
||||
|
||||
// Set puts a value in the buffer to be included
|
||||
// in the next emission. It overwrites any
|
||||
// previous value.
|
||||
//
|
||||
// This function is safe for multiple goroutines,
|
||||
// and it is recommended to call this using the
|
||||
// go keyword after the call to SendHello so it
|
||||
// doesn't block crucial code.
|
||||
func Set(key string, val interface{}) {
|
||||
if !enabled || isDisabled(key) {
|
||||
return
|
||||
}
|
||||
bufferMu.Lock()
|
||||
if _, ok := buffer[key]; !ok {
|
||||
if bufferItemCount >= maxBufferItems {
|
||||
bufferMu.Unlock()
|
||||
return
|
||||
}
|
||||
bufferItemCount++
|
||||
}
|
||||
buffer[key] = val
|
||||
bufferMu.Unlock()
|
||||
}
|
||||
|
||||
// SetNested puts a value in the buffer to be included
|
||||
// in the next emission, nested under the top-level key
|
||||
// as subkey. It overwrites any previous value.
|
||||
//
|
||||
// This function is safe for multiple goroutines,
|
||||
// and it is recommended to call this using the
|
||||
// go keyword after the call to SendHello so it
|
||||
// doesn't block crucial code.
|
||||
func SetNested(key, subkey string, val interface{}) {
|
||||
if !enabled || isDisabled(key) {
|
||||
return
|
||||
}
|
||||
bufferMu.Lock()
|
||||
if topLevel, ok1 := buffer[key]; ok1 {
|
||||
topLevelMap, ok2 := topLevel.(map[string]interface{})
|
||||
if !ok2 {
|
||||
bufferMu.Unlock()
|
||||
log.Printf("[PANIC] Telemetry: key %s is already used for non-nested-map value", key)
|
||||
return
|
||||
}
|
||||
if _, ok3 := topLevelMap[subkey]; !ok3 {
|
||||
// don't exceed max buffer size
|
||||
if bufferItemCount >= maxBufferItems {
|
||||
bufferMu.Unlock()
|
||||
return
|
||||
}
|
||||
bufferItemCount++
|
||||
}
|
||||
topLevelMap[subkey] = val
|
||||
} else {
|
||||
// don't exceed max buffer size
|
||||
if bufferItemCount >= maxBufferItems {
|
||||
bufferMu.Unlock()
|
||||
return
|
||||
}
|
||||
bufferItemCount++
|
||||
buffer[key] = map[string]interface{}{subkey: val}
|
||||
}
|
||||
bufferMu.Unlock()
|
||||
}
|
||||
|
||||
// Append appends value to a list named key.
|
||||
// If key is new, a new list will be created.
|
||||
// If key maps to a type that is not a list,
|
||||
// a panic is logged, and this is a no-op.
|
||||
func Append(key string, value interface{}) {
|
||||
if !enabled || isDisabled(key) {
|
||||
return
|
||||
}
|
||||
bufferMu.Lock()
|
||||
if bufferItemCount >= maxBufferItems {
|
||||
bufferMu.Unlock()
|
||||
return
|
||||
}
|
||||
// TODO: Test this...
|
||||
bufVal, inBuffer := buffer[key]
|
||||
sliceVal, sliceOk := bufVal.([]interface{})
|
||||
if inBuffer && !sliceOk {
|
||||
bufferMu.Unlock()
|
||||
log.Printf("[PANIC] Telemetry: key %s already used for non-slice value", key)
|
||||
return
|
||||
}
|
||||
if sliceVal == nil {
|
||||
buffer[key] = []interface{}{value}
|
||||
} else if sliceOk {
|
||||
buffer[key] = append(sliceVal, value)
|
||||
}
|
||||
bufferItemCount++
|
||||
bufferMu.Unlock()
|
||||
}
|
||||
|
||||
// AppendUnique adds value to a set named key.
|
||||
// Set items are unordered. Values in the set
|
||||
// are unique, but how many times they are
|
||||
// appended is counted. The value must be
|
||||
// hashable.
|
||||
//
|
||||
// If key is new, a new set will be created for
|
||||
// values with that key. If key maps to a type
|
||||
// that is not a counting set, a panic is logged,
|
||||
// and this is a no-op.
|
||||
func AppendUnique(key string, value interface{}) {
|
||||
if !enabled || isDisabled(key) {
|
||||
return
|
||||
}
|
||||
bufferMu.Lock()
|
||||
bufVal, inBuffer := buffer[key]
|
||||
setVal, setOk := bufVal.(countingSet)
|
||||
if inBuffer && !setOk {
|
||||
bufferMu.Unlock()
|
||||
log.Printf("[PANIC] Telemetry: key %s already used for non-counting-set value", key)
|
||||
return
|
||||
}
|
||||
if setVal == nil {
|
||||
// ensure the buffer is not too full, then add new unique value
|
||||
if bufferItemCount >= maxBufferItems {
|
||||
bufferMu.Unlock()
|
||||
return
|
||||
}
|
||||
buffer[key] = countingSet{value: 1}
|
||||
bufferItemCount++
|
||||
} else if setOk {
|
||||
// unique value already exists, so just increment counter
|
||||
setVal[value]++
|
||||
}
|
||||
bufferMu.Unlock()
|
||||
}
|
||||
|
||||
// Add adds amount to a value named key.
|
||||
// If it does not exist, it is created with
|
||||
// a value of 1. If key maps to a type that
|
||||
// is not an integer, a panic is logged,
|
||||
// and this is a no-op.
|
||||
func Add(key string, amount int) {
|
||||
atomicAdd(key, amount)
|
||||
}
|
||||
|
||||
// Increment is a shortcut for Add(key, 1)
|
||||
func Increment(key string) {
|
||||
atomicAdd(key, 1)
|
||||
}
|
||||
|
||||
// atomicAdd adds amount (negative to subtract)
|
||||
// to key.
|
||||
func atomicAdd(key string, amount int) {
|
||||
if !enabled || isDisabled(key) {
|
||||
return
|
||||
}
|
||||
bufferMu.Lock()
|
||||
bufVal, inBuffer := buffer[key]
|
||||
intVal, intOk := bufVal.(int)
|
||||
if inBuffer && !intOk {
|
||||
bufferMu.Unlock()
|
||||
log.Printf("[PANIC] Telemetry: key %s already used for non-integer value", key)
|
||||
return
|
||||
}
|
||||
if !inBuffer {
|
||||
if bufferItemCount >= maxBufferItems {
|
||||
bufferMu.Unlock()
|
||||
return
|
||||
}
|
||||
bufferItemCount++
|
||||
}
|
||||
buffer[key] = intVal + amount
|
||||
bufferMu.Unlock()
|
||||
}
|
||||
|
||||
// FastHash hashes input using a 32-bit hashing algorithm
|
||||
// that is fast, and returns the hash as a hex-encoded string.
|
||||
// Do not use this for cryptographic purposes.
|
||||
func FastHash(input []byte) string {
|
||||
h := fnv.New32a()
|
||||
h.Write(input)
|
||||
return fmt.Sprintf("%x", h.Sum32())
|
||||
}
|
||||
|
||||
// isDisabled returns whether key is
|
||||
// a disabled metric key. ALL collection
|
||||
// functions should call this and not
|
||||
// save the value if this returns true.
|
||||
func isDisabled(key string) bool {
|
||||
// for keys that are augmented with data, such as
|
||||
// "tls_client_hello_ua:<hash>", just
|
||||
// check the prefix "tls_client_hello_ua"
|
||||
checkKey := key
|
||||
if idx := strings.Index(key, ":"); idx > -1 {
|
||||
checkKey = key[:idx]
|
||||
}
|
||||
|
||||
disabledMetricsMu.RLock()
|
||||
_, ok := disabledMetrics[checkKey]
|
||||
disabledMetricsMu.RUnlock()
|
||||
return ok
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestInit(t *testing.T) {
|
||||
reset()
|
||||
|
||||
id := doInit(t) // should not panic
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("Second call to Init should have panicked")
|
||||
}
|
||||
}()
|
||||
Init(id, nil) // should panic
|
||||
}
|
||||
|
||||
func TestInitEmptyUUID(t *testing.T) {
|
||||
reset()
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Errorf("Call to Init with empty UUID should have panicked")
|
||||
}
|
||||
}()
|
||||
Init(uuid.UUID([16]byte{}), nil)
|
||||
}
|
||||
|
||||
func TestSet(t *testing.T) {
|
||||
reset()
|
||||
|
||||
// should be no-op since we haven't called Init() yet
|
||||
Set("test1", "foobar")
|
||||
if _, ok := buffer["test"]; ok {
|
||||
t.Errorf("Should not have inserted item when not initialized")
|
||||
}
|
||||
|
||||
// should work after we've initialized
|
||||
doInit(t)
|
||||
Set("test1", "foobar")
|
||||
val, ok := buffer["test1"]
|
||||
if !ok {
|
||||
t.Errorf("Expected value to be in buffer, but it wasn't")
|
||||
} else if val.(string) != "foobar" {
|
||||
t.Errorf("Expected 'foobar', got '%v'", val)
|
||||
}
|
||||
|
||||
// should not overfill buffer
|
||||
maxBufferItemsTmp := maxBufferItems
|
||||
maxBufferItems = 10
|
||||
for i := 0; i < maxBufferItems+1; i++ {
|
||||
Set(fmt.Sprintf("overfill_%d", i), "foobar")
|
||||
}
|
||||
if len(buffer) > maxBufferItems {
|
||||
t.Errorf("Should not exceed max buffer size (%d); has %d items",
|
||||
maxBufferItems, len(buffer))
|
||||
}
|
||||
maxBufferItems = maxBufferItemsTmp
|
||||
|
||||
// Should overwrite values
|
||||
Set("test1", "foobar2")
|
||||
val, ok = buffer["test1"]
|
||||
if !ok {
|
||||
t.Errorf("Expected value to be in buffer, but it wasn't")
|
||||
} else if val.(string) != "foobar2" {
|
||||
t.Errorf("Expected 'foobar2', got '%v'", val)
|
||||
}
|
||||
}
|
||||
|
||||
// doInit calls Init() with a valid UUID
|
||||
// and returns it.
|
||||
func doInit(t *testing.T) uuid.UUID {
|
||||
id, err := uuid.Parse(testUUID)
|
||||
if err != nil {
|
||||
t.Fatalf("Could not make UUID: %v", err)
|
||||
}
|
||||
Init(id, nil)
|
||||
return id
|
||||
}
|
||||
|
||||
// reset resets all the lovely package-level state;
|
||||
// can be used as a set up function in tests.
|
||||
func reset() {
|
||||
instanceUUID = uuid.UUID{}
|
||||
buffer = make(map[string]interface{})
|
||||
bufferItemCount = 0
|
||||
updating = false
|
||||
enabled = false
|
||||
}
|
||||
|
||||
const testUUID = "0b6cfa22-0d4c-11e8-b11b-7a0058e13201"
|
||||
@@ -0,0 +1,428 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package telemetry implements the client for server-side telemetry
|
||||
// of the network. Functions in this package are synchronous and blocking
|
||||
// unless otherwise specified. For convenience, most functions here do
|
||||
// not return errors, but errors are logged to the standard logger.
|
||||
//
|
||||
// To use this package, first call Init(). You can then call any of the
|
||||
// collection/aggregation functions. Call StartEmitting() when you are
|
||||
// ready to begin sending telemetry updates.
|
||||
//
|
||||
// When collecting metrics (functions like Set, AppendUnique, or Increment),
|
||||
// it may be desirable and even recommended to invoke them in a new
|
||||
// goroutine in case there is lock contention; they are thread-safe (unless
|
||||
// noted), and you may not want them to block the main thread of execution.
|
||||
// However, sometimes blocking may be necessary too; for example, adding
|
||||
// startup metrics to the buffer before the call to StartEmitting().
|
||||
//
|
||||
// This package is designed to be as fast and space-efficient as reasonably
|
||||
// possible, so that it does not disrupt the flow of execution.
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// logEmit calls emit and then logs the error, if any.
|
||||
// See docs for emit.
|
||||
func logEmit(final bool) {
|
||||
err := emit(final)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] Sending telemetry: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// emit sends an update to the telemetry server.
|
||||
// Set final to true if this is the last call to emit.
|
||||
// If final is true, no future updates will be scheduled.
|
||||
// Otherwise, the next update will be scheduled.
|
||||
func emit(final bool) error {
|
||||
if !enabled {
|
||||
return fmt.Errorf("telemetry not enabled")
|
||||
}
|
||||
|
||||
// some metrics are updated/set at time of emission
|
||||
setEmitTimeMetrics()
|
||||
|
||||
// ensure only one update happens at a time;
|
||||
// skip update if previous one still in progress
|
||||
updateMu.Lock()
|
||||
if updating {
|
||||
updateMu.Unlock()
|
||||
log.Println("[NOTICE] Skipping this telemetry update because previous one is still working")
|
||||
return nil
|
||||
}
|
||||
updating = true
|
||||
updateMu.Unlock()
|
||||
defer func() {
|
||||
updateMu.Lock()
|
||||
updating = false
|
||||
updateMu.Unlock()
|
||||
}()
|
||||
|
||||
// terminate any pending update if this is the last one
|
||||
if final {
|
||||
stopUpdateTimer()
|
||||
}
|
||||
|
||||
payloadBytes, err := makePayloadAndResetBuffer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// this will hold the server's reply
|
||||
var reply Response
|
||||
|
||||
// transmit the payload - use a loop to retry in case of failure
|
||||
for i := 0; i < 4; i++ {
|
||||
if i > 0 && err != nil {
|
||||
// don't hammer the server; first failure might have been
|
||||
// a fluke, but back off more after that
|
||||
log.Printf("[WARNING] Sending telemetry (attempt %d): %v - backing off and retrying", i, err)
|
||||
time.Sleep(time.Duration((i+1)*(i+1)*(i+1)) * time.Second)
|
||||
}
|
||||
|
||||
// send it
|
||||
var resp *http.Response
|
||||
resp, err = httpClient.Post(endpoint+instanceUUID.String(), "application/json", bytes.NewReader(payloadBytes))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// check for any special-case response codes
|
||||
if resp.StatusCode == http.StatusGone {
|
||||
// the endpoint has been deprecated and is no longer servicing clients
|
||||
err = fmt.Errorf("telemetry server replied with HTTP %d; upgrade required", resp.StatusCode)
|
||||
if clen := resp.Header.Get("Content-Length"); clen != "0" && clen != "" {
|
||||
bodyBytes, readErr := ioutil.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
log.Printf("[ERROR] Reading response body from server: %v", readErr)
|
||||
}
|
||||
err = fmt.Errorf("%v - %s", err, bodyBytes)
|
||||
}
|
||||
resp.Body.Close()
|
||||
reply.Stop = true
|
||||
break
|
||||
}
|
||||
if resp.StatusCode == http.StatusUnavailableForLegalReasons {
|
||||
// the endpoint is unavailable, at least to this client, for legal reasons (!)
|
||||
err = fmt.Errorf("telemetry server replied with HTTP %d %s: please consult the project website and developers for guidance", resp.StatusCode, resp.Status)
|
||||
if clen := resp.Header.Get("Content-Length"); clen != "0" && clen != "" {
|
||||
bodyBytes, readErr := ioutil.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
log.Printf("[ERROR] Reading response body from server: %v", readErr)
|
||||
}
|
||||
err = fmt.Errorf("%v - %s", err, bodyBytes)
|
||||
}
|
||||
resp.Body.Close()
|
||||
reply.Stop = true
|
||||
break
|
||||
}
|
||||
|
||||
// okay, ensure we can interpret the response
|
||||
if ct := resp.Header.Get("Content-Type"); (resp.StatusCode < 300 || resp.StatusCode >= 400) &&
|
||||
!strings.Contains(ct, "json") {
|
||||
err = fmt.Errorf("telemetry server replied with unknown content-type: '%s' and HTTP %s", ct, resp.Status)
|
||||
resp.Body.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
// read the response body
|
||||
err = json.NewDecoder(resp.Body).Decode(&reply)
|
||||
resp.Body.Close() // close response body as soon as we're done with it
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// update the list of enabled/disabled keys, if any
|
||||
for _, key := range reply.EnableKeys {
|
||||
disabledMetricsMu.Lock()
|
||||
// only re-enable this metric if it is temporarily disabled
|
||||
if temp, ok := disabledMetrics[key]; ok && temp {
|
||||
delete(disabledMetrics, key)
|
||||
}
|
||||
disabledMetricsMu.Unlock()
|
||||
}
|
||||
for _, key := range reply.DisableKeys {
|
||||
disabledMetricsMu.Lock()
|
||||
disabledMetrics[key] = true // all remotely-disabled keys are "temporarily" disabled
|
||||
disabledMetricsMu.Unlock()
|
||||
}
|
||||
|
||||
// make sure we didn't send the update too soon; if so,
|
||||
// just wait and try again -- this is a special case of
|
||||
// error that we handle differently, as you can see
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
if reply.NextUpdate <= 0 {
|
||||
raStr := resp.Header.Get("Retry-After")
|
||||
if ra, err := strconv.Atoi(raStr); err == nil {
|
||||
reply.NextUpdate = time.Duration(ra) * time.Second
|
||||
}
|
||||
}
|
||||
if !final {
|
||||
log.Printf("[NOTICE] Sending telemetry: we were too early; waiting %s before trying again", reply.NextUpdate)
|
||||
time.Sleep(reply.NextUpdate)
|
||||
continue
|
||||
}
|
||||
} else if resp.StatusCode >= 400 {
|
||||
err = fmt.Errorf("telemetry server returned status code %d", resp.StatusCode)
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
if err == nil && !final {
|
||||
// (remember, if there was an error, we return it
|
||||
// below, so it WILL get logged if it's supposed to)
|
||||
log.Println("[INFO] Sending telemetry: success")
|
||||
}
|
||||
|
||||
// even if there was an error after all retries, we should
|
||||
// schedule the next update using our default update
|
||||
// interval because the server might be healthy later
|
||||
|
||||
// ensure we won't slam the telemetry server; add a little variance
|
||||
if reply.NextUpdate < 1*time.Second {
|
||||
reply.NextUpdate = defaultUpdateInterval + time.Duration(rand.Int63n(int64(1*time.Minute)))
|
||||
}
|
||||
|
||||
// schedule the next update (if this wasn't the last one and
|
||||
// if the remote server didn't tell us to stop sending)
|
||||
if !final && !reply.Stop {
|
||||
updateTimerMu.Lock()
|
||||
updateTimer = time.AfterFunc(reply.NextUpdate, func() {
|
||||
logEmit(false)
|
||||
})
|
||||
updateTimerMu.Unlock()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func stopUpdateTimer() {
|
||||
updateTimerMu.Lock()
|
||||
updateTimer.Stop()
|
||||
updateTimer = nil
|
||||
updateTimerMu.Unlock()
|
||||
}
|
||||
|
||||
// setEmitTimeMetrics sets some metrics that should
|
||||
// be recorded just before emitting.
|
||||
func setEmitTimeMetrics() {
|
||||
Set("goroutines", runtime.NumGoroutine())
|
||||
|
||||
var mem runtime.MemStats
|
||||
runtime.ReadMemStats(&mem)
|
||||
SetNested("memory", "heap_alloc", mem.HeapAlloc)
|
||||
SetNested("memory", "sys", mem.Sys)
|
||||
}
|
||||
|
||||
// makePayloadAndResetBuffer prepares a payload
|
||||
// by emptying the collection buffer. It returns
|
||||
// the bytes of the payload to send to the server.
|
||||
// Since the buffer is reset by this, if the
|
||||
// resulting byte slice is lost, the payload is
|
||||
// gone with it.
|
||||
func makePayloadAndResetBuffer() ([]byte, error) {
|
||||
bufCopy := resetBuffer()
|
||||
|
||||
// encode payload in preparation for transmission
|
||||
payload := Payload{
|
||||
InstanceID: instanceUUID.String(),
|
||||
Timestamp: time.Now().UTC(),
|
||||
Data: bufCopy,
|
||||
}
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
// resetBuffer makes a local pointer to the buffer,
|
||||
// then resets the buffer by assigning to be a newly-
|
||||
// made value to clear it out, then sets the buffer
|
||||
// item count to 0. It returns the copied pointer to
|
||||
// the original map so the old buffer value can be
|
||||
// used locally.
|
||||
func resetBuffer() map[string]interface{} {
|
||||
bufferMu.Lock()
|
||||
bufCopy := buffer
|
||||
buffer = make(map[string]interface{})
|
||||
bufferItemCount = 0
|
||||
bufferMu.Unlock()
|
||||
return bufCopy
|
||||
}
|
||||
|
||||
// Response contains the body of a response from the
|
||||
// telemetry server.
|
||||
type Response struct {
|
||||
// NextUpdate is how long to wait before the next update.
|
||||
NextUpdate time.Duration `json:"next_update"`
|
||||
|
||||
// Stop instructs the telemetry server to stop sending
|
||||
// telemetry. This would only be done under extenuating
|
||||
// circumstances, but we are prepared for it nonetheless.
|
||||
Stop bool `json:"stop,omitempty"`
|
||||
|
||||
// Error will be populated with an error message, if any.
|
||||
// This field should be empty if the status code is < 400.
|
||||
Error string `json:"error,omitempty"`
|
||||
|
||||
// DisableKeys will contain a list of keys/metrics that
|
||||
// should NOT be sent until further notice. The client
|
||||
// must NOT store these items in its buffer or send them
|
||||
// to the telemetry server while they are disabled. If
|
||||
// this list and EnableKeys have the same value (which is
|
||||
// not supposed to happen), this field should dominate.
|
||||
DisableKeys []string `json:"disable_keys,omitempty"`
|
||||
|
||||
// EnableKeys will contain a list of keys/metrics that
|
||||
// MAY be sent until further notice.
|
||||
EnableKeys []string `json:"enable_keys,omitempty"`
|
||||
}
|
||||
|
||||
// Payload is the data that gets sent to the telemetry server.
|
||||
type Payload struct {
|
||||
// The universally unique ID of the instance
|
||||
InstanceID string `json:"instance_id"`
|
||||
|
||||
// The UTC timestamp of the transmission
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
|
||||
// The timestamp before which the next update is expected
|
||||
// (NOT populated by client - the server fills this in
|
||||
// before it stores the data)
|
||||
ExpectNext time.Time `json:"expect_next,omitempty"`
|
||||
|
||||
// The metrics
|
||||
Data map[string]interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// Int returns the value of the data keyed by key
|
||||
// if it is an integer; otherwise it returns 0.
|
||||
func (p Payload) Int(key string) int {
|
||||
val, _ := p.Data[key]
|
||||
switch p.Data[key].(type) {
|
||||
case int:
|
||||
return val.(int)
|
||||
case float64: // after JSON-decoding, int becomes float64...
|
||||
return int(val.(float64))
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// countingSet implements a set that counts how many
|
||||
// times a key is inserted. It marshals to JSON in a
|
||||
// way such that keys are converted to values next
|
||||
// to their associated counts.
|
||||
type countingSet map[interface{}]int
|
||||
|
||||
// MarshalJSON implements the json.Marshaler interface.
|
||||
// It converts the set to an array so that the values
|
||||
// are JSON object values instead of keys, since keys
|
||||
// are difficult to query in databases.
|
||||
func (s countingSet) MarshalJSON() ([]byte, error) {
|
||||
type Item struct {
|
||||
Value interface{} `json:"value"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
var list []Item
|
||||
|
||||
for k, v := range s {
|
||||
list = append(list, Item{Value: k, Count: v})
|
||||
}
|
||||
|
||||
return json.Marshal(list)
|
||||
}
|
||||
|
||||
var (
|
||||
// httpClient should be used for HTTP requests. It
|
||||
// is configured with a timeout for reliability.
|
||||
httpClient = http.Client{
|
||||
Transport: &http.Transport{
|
||||
TLSHandshakeTimeout: 30 * time.Second,
|
||||
DisableKeepAlives: true,
|
||||
},
|
||||
Timeout: 1 * time.Minute,
|
||||
}
|
||||
|
||||
// buffer holds the data that we are building up to send.
|
||||
buffer = make(map[string]interface{})
|
||||
bufferItemCount = 0
|
||||
bufferMu sync.RWMutex // protects both the buffer and its count
|
||||
|
||||
// updating is used to ensure only one
|
||||
// update happens at a time.
|
||||
updating bool
|
||||
updateMu sync.Mutex
|
||||
|
||||
// updateTimer fires off the next update.
|
||||
// If no update is scheduled, this is nil.
|
||||
updateTimer *time.Timer
|
||||
updateTimerMu sync.Mutex
|
||||
|
||||
// disabledMetrics is a set of metric keys
|
||||
// that should NOT be saved to the buffer
|
||||
// or sent to the telemetry server. The value
|
||||
// indicates whether the entry is temporary.
|
||||
// If the value is true, it may be removed if
|
||||
// the metric is re-enabled remotely later. If
|
||||
// the value is false, it is permanent
|
||||
// (presumably becaues the user explicitly
|
||||
// disabled it) and can only be re-enabled
|
||||
// with user consent.
|
||||
disabledMetrics = make(map[string]bool)
|
||||
disabledMetricsMu sync.RWMutex
|
||||
|
||||
// instanceUUID is the ID of the current instance.
|
||||
// This MUST be set to emit telemetry.
|
||||
// This MUST NOT be openly exposed to clients, for privacy.
|
||||
instanceUUID uuid.UUID
|
||||
|
||||
// enabled indicates whether the package has
|
||||
// been initialized and can be actively used.
|
||||
enabled bool
|
||||
|
||||
// maxBufferItems is the maximum number of items we'll allow
|
||||
// in the buffer before we start dropping new ones, in a
|
||||
// rough (simple) attempt to keep memory use under control.
|
||||
maxBufferItems = 100000
|
||||
)
|
||||
|
||||
const (
|
||||
// endpoint is the base URL to remote telemetry server;
|
||||
// the instance ID will be appended to it.
|
||||
endpoint = "https://telemetry.caddyserver.com/v1/update/"
|
||||
|
||||
// defaultUpdateInterval is how long to wait before emitting
|
||||
// more telemetry data if all retires fail. This value is
|
||||
// only used if the client receives a nonsensical value, or
|
||||
// doesn't send one at all, or if a connection can't be made,
|
||||
// likely indicating a problem with the server. Thus, this
|
||||
// value should be a long duration to help alleviate extra
|
||||
// load on the server.
|
||||
defaultUpdateInterval = 1 * time.Hour
|
||||
)
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright 2015 Light Code Labs, LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMakePayloadAndResetBuffer(t *testing.T) {
|
||||
reset()
|
||||
id := doInit(t)
|
||||
|
||||
buffer = map[string]interface{}{
|
||||
"foo1": "bar1",
|
||||
"foo2": "bar2",
|
||||
}
|
||||
bufferItemCount = 2
|
||||
|
||||
payloadBytes, err := makePayloadAndResetBuffer()
|
||||
if err != nil {
|
||||
t.Fatalf("Error making payload bytes: %v", err)
|
||||
}
|
||||
|
||||
if len(buffer) != 0 {
|
||||
t.Errorf("Expected buffer len to be 0, got %d", len(buffer))
|
||||
}
|
||||
if bufferItemCount != 0 {
|
||||
t.Errorf("Expected buffer item count to be 0, got %d", bufferItemCount)
|
||||
}
|
||||
|
||||
var payload Payload
|
||||
err = json.Unmarshal(payloadBytes, &payload)
|
||||
if err != nil {
|
||||
t.Fatalf("Error deserializing payload: %v", err)
|
||||
}
|
||||
|
||||
if payload.InstanceID != id.String() {
|
||||
t.Errorf("Expected instance ID to be set to '%s' but got '%s'", testUUID, payload.InstanceID)
|
||||
}
|
||||
if payload.Data == nil {
|
||||
t.Errorf("Expected data to be set, but was nil")
|
||||
}
|
||||
if payload.Timestamp.IsZero() {
|
||||
t.Errorf("Expected timestamp to be set, but was zero value")
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user