commit 0401b66150477f5fe53a63dfdaedc23dff5e6501
parent 0ca7bfa1adc0d953ac2ef4d88752ee23a1896928
Author: Lukas Henkel <lh@entf.net>
Date: Tue, 16 Feb 2021 21:35:53 +0100
Refactored functionality for use as a library
Diffstat:
5 files changed, 149 insertions(+), 70 deletions(-)
diff --git a/cmd/htmlattr/main.go b/cmd/htmlattr/main.go
@@ -20,39 +20,23 @@ func main() {
os.Exit(1)
}
attrs := strings.Split(args[0], fs)
- for i, attr := range attrs {
- attrs[i] = strings.ToLower(attr)
- }
htmltools.Main(args[1:], func(doc *html.Node) {
- var body *html.Node
- for n := doc.FirstChild.FirstChild; n != nil; n = n.NextSibling {
- if strings.ToLower(n.Data) == "body" {
- body = n
- break
- }
+ body, err := htmltools.Body(doc)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
+ } else if body == nil {
+ fmt.Fprintln(os.Stderr, "Document does not contain a body")
+ os.Exit(1)
}
- if body == nil {
- fmt.Fprintln(os.Stderr, "document does not contain a body")
+ values, err := htmltools.Attr(attrs, htmltools.Children(doc)...)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
- for n := body.FirstChild; n != nil; n = n.NextSibling {
- if n.Type != html.ElementNode {
- continue
- }
- list := make([]string, len(attrs))
- var any bool
- for i, attrn := range attrs {
- for _, attr := range n.Attr {
- if strings.ToLower(attr.Key) == attrn {
- any = true
- list[i] = attr.Val
- }
- }
- }
- line := strings.Join(list, fs)
- if any {
- fmt.Println(line)
- }
+ for _, v := range values {
+ line := strings.Join(v, fs)
+ fmt.Println(line)
}
})
}
diff --git a/cmd/htmlindentheadings/main.go b/cmd/htmlindentheadings/main.go
@@ -24,28 +24,9 @@ func main() {
os.Exit(1)
}
htmltools.Main(args[1:], func(doc *html.Node) {
- visit(lvls, doc)
+ for node := range htmltools.FindRecursive(doc, nil) {
+ htmltools.IndentHeadings(lvls, node)
+ }
html.Render(os.Stdout, doc)
})
}
-
-func indent(lvls int, tag string) string {
- l := int(tag[1]) - 48
- l += lvls
- if l > 6 {
- l = 6
- }
- return fmt.Sprintf("h%d", l)
-}
-
-func visit(lvls int, n *html.Node) {
- if n.Type == html.ElementNode {
- switch n.Data {
- case "h1", "h2", "h3", "h4", "h5", "h6":
- n.Data = indent(lvls, n.Data)
- }
- }
- for c := n.FirstChild; c != nil; c = c.NextSibling {
- visit(lvls, c)
- }
-}
diff --git a/cmd/htmltotext/main.go b/cmd/htmltotext/main.go
@@ -11,16 +11,13 @@ import (
)
func main() {
- htmltools.Main(os.Args[1:], visit)
-}
-
-func visit(n *html.Node) {
- if n.Type == html.TextNode {
- if t := strings.TrimSpace(n.Data); t != "" {
- fmt.Println(t)
+ htmltools.Main(os.Args[1:], func(doc *html.Node) {
+ for n := range htmltools.FindRecursive(
+ doc,
+ htmltools.MatchNodeTypeFunc(html.TextNode)) {
+ if t := strings.TrimSpace(n.Data); t != "" {
+ fmt.Println(t)
+ }
}
- }
- for c := n.FirstChild; c != nil; c = c.NextSibling {
- visit(c)
- }
+ })
}
diff --git a/cmd/htmlunwrap/main.go b/cmd/htmlunwrap/main.go
@@ -28,15 +28,10 @@ func main() {
func unwrap(sel cascadia.Selector, doc *html.Node) {
for _, n := range sel.MatchAll(doc) {
- cs := make([]*html.Node, 0)
- for c := n.FirstChild; c != nil; c = c.NextSibling {
- cs = append(cs, c)
+ if err := htmltools.Unwrap(n); err != nil {
+ fmt.Fprintf(os.Stderr, "%v\n", err)
+ os.Exit(1)
}
- for _, c := range cs {
- n.RemoveChild(c)
- n.Parent.InsertBefore(c, n)
- }
- n.Parent.RemoveChild(n)
}
html.Render(os.Stdout, doc)
}
diff --git a/htmltools.go b/htmltools.go
@@ -0,0 +1,122 @@
+package htmltools
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+
+ "golang.org/x/net/html"
+)
+
+var (
+ ErrNodeIsNotADocumentNode = errors.New("Not a document node")
+ ErrNodeHasNoParent = errors.New("Node has no parent")
+)
+
+type NodeMatchFunc func(*html.Node) bool
+
+// Gets the body from an HTML document node.
+func Body(doc *html.Node) (*html.Node, error) {
+ if doc.Type != html.DocumentNode {
+ return nil, ErrNodeIsNotADocumentNode
+ }
+ var body *html.Node
+ for n := doc.FirstChild.FirstChild; n != nil; n = n.NextSibling {
+ if strings.ToLower(n.Data) == "body" {
+ body = n
+ break
+ }
+ }
+ return body, nil
+}
+
+// Gets all direct children.
+func Children(node *html.Node) []*html.Node {
+ nodes := make([]*html.Node, 0)
+ for n := node.FirstChild; n != nil; n = n.NextSibling {
+ nodes = append(nodes, n)
+ }
+ return nodes
+}
+
+func findRecursive(node *html.Node, nodeFunc func(*html.Node) bool, ch chan<- *html.Node) {
+ if nodeFunc == nil || nodeFunc(node) {
+ ch <- node
+ }
+ for _, c := range Children(node) {
+ findRecursive(c, nodeFunc, ch)
+ }
+}
+
+// Returns a channel providing all nodes that match nodeFunc recursively through
+// the whole document. If nodeFunc is `nil`, all nodes match.
+func FindRecursive(doc *html.Node, nodeFunc NodeMatchFunc) <-chan *html.Node {
+ ch := make(chan *html.Node)
+ go findRecursive(doc, nodeFunc, ch)
+ return ch
+}
+
+// Returns all attribite values specified in attrs for nodes.
+func Attr(attrs []string, nodes ...*html.Node) ([][]string, error) {
+ for i, attr := range attrs {
+ attrs[i] = strings.ToLower(attr)
+ }
+ results := make([][]string, 0)
+ for _, n := range nodes {
+ if n.Type != html.ElementNode {
+ continue
+ }
+ list := make([]string, len(attrs))
+ var any bool
+ for i, attrn := range attrs {
+ for _, attr := range n.Attr {
+ if strings.ToLower(attr.Key) == attrn {
+ any = true
+ list[i] = attr.Val
+ }
+ }
+ }
+ if any {
+ results = append(results, list)
+ }
+ }
+ return results, nil
+}
+
+// Indents all headings by a certain level.
+func IndentHeadings(level int, nodes ...*html.Node) error {
+ for _, n := range nodes {
+ switch strings.ToLower(n.Data) {
+ case "h1", "h2", "h3", "h4", "h5", "h6":
+ default:
+ continue
+ }
+ l := int(n.Data[1]) - 48 //HACK: ASCII to number
+ l += level
+ if l > 6 {
+ l = 6
+ }
+ n.Data = fmt.Sprintf("h%d", l)
+ }
+ return nil
+}
+
+// Removes node from parent and replaces it by it's children.
+func Unwrap(node *html.Node) error {
+ if node.Parent == nil {
+ return ErrNodeHasNoParent
+ }
+ for _, c := range Children(node) {
+ node.RemoveChild(c)
+ node.Parent.InsertBefore(c, node)
+ }
+ node.Parent.RemoveChild(node)
+ return nil
+}
+
+// Creates a NodeMatchFunc, matching a certain NodeType
+func MatchNodeTypeFunc(nodeType html.NodeType) NodeMatchFunc {
+ return func(node *html.Node) bool {
+ return node.Type == nodeType
+ }
+}