1   /*
2    * SPDX-FileCopyrightText: none
3    * SPDX-License-Identifier: CC0-1.0
4    */
5   
6   package gov.nist.secauto.oscal.lib.model.util;
7   
8   import gov.nist.secauto.metaschema.core.metapath.DynamicContext;
9   import gov.nist.secauto.metaschema.core.metapath.ISequence;
10  import gov.nist.secauto.metaschema.core.metapath.MetapathExpression;
11  import gov.nist.secauto.metaschema.core.metapath.StaticContext;
12  import gov.nist.secauto.metaschema.core.metapath.item.node.AbstractRecursionPreventingNodeItemVisitor;
13  import gov.nist.secauto.metaschema.core.metapath.item.node.IAssemblyInstanceGroupedNodeItem;
14  import gov.nist.secauto.metaschema.core.metapath.item.node.IAssemblyNodeItem;
15  import gov.nist.secauto.metaschema.core.metapath.item.node.IDefinitionNodeItem;
16  import gov.nist.secauto.metaschema.core.metapath.item.node.IFieldNodeItem;
17  import gov.nist.secauto.metaschema.core.metapath.item.node.IFlagNodeItem;
18  import gov.nist.secauto.metaschema.core.metapath.item.node.IModuleNodeItem;
19  import gov.nist.secauto.metaschema.core.metapath.item.node.INodeItemFactory;
20  import gov.nist.secauto.metaschema.core.model.IModule;
21  import gov.nist.secauto.metaschema.core.model.constraint.IAllowedValuesConstraint;
22  
23  import java.util.Collection;
24  import java.util.LinkedHashMap;
25  import java.util.LinkedList;
26  import java.util.List;
27  import java.util.Map;
28  
29  import edu.umd.cs.findbugs.annotations.NonNull;
30  
31  public class AllowedValueCollectingNodeItemVisitor
32      extends AbstractRecursionPreventingNodeItemVisitor<DynamicContext, Void> {
33  
34    private final Map<IDefinitionNodeItem<?, ?>, NodeItemRecord> nodeItemAnalysis = new LinkedHashMap<>();
35  
36    public Collection<NodeItemRecord> getAllowedValueLocations() {
37      return nodeItemAnalysis.values();
38    }
39  
40    public void visit(@NonNull IModule module) {
41      DynamicContext context = new DynamicContext(
42          StaticContext.builder()
43              .defaultModelNamespace(module.getXmlNamespace())
44              .build());
45      context.disablePredicateEvaluation();
46  
47      visit(INodeItemFactory.instance().newModuleNodeItem(module), context);
48    }
49  
50    public void visit(@NonNull IModuleNodeItem module, @NonNull DynamicContext context) {
51  
52      visitMetaschema(module, context);
53    }
54  
55    private void handleAllowedValuesAtLocation(@NonNull IDefinitionNodeItem<?, ?> itemLocation, DynamicContext context) {
56      itemLocation.getDefinition().getAllowedValuesConstraints().stream()
57          .forEachOrdered(allowedValues -> {
58            String metapath = allowedValues.getTarget();
59  
60            MetapathExpression path = MetapathExpression.compile(metapath, context.getStaticContext());
61            ISequence<?> result = path.evaluate(itemLocation, context);
62            result.stream().forEachOrdered(target -> {
63              handleAllowedValues(allowedValues, itemLocation, (IDefinitionNodeItem<?, ?>) target);
64            });
65          });
66    }
67  
68    private void handleAllowedValues(
69        @NonNull IAllowedValuesConstraint allowedValues,
70        @NonNull IDefinitionNodeItem<?, ?> location,
71        @NonNull IDefinitionNodeItem<?, ?> target) {
72      NodeItemRecord itemRecord = nodeItemAnalysis.get(target);
73      if (itemRecord == null) {
74        itemRecord = new NodeItemRecord(target);
75        nodeItemAnalysis.put(target, itemRecord);
76      }
77  
78      AllowedValuesRecord allowedValuesRecord = new AllowedValuesRecord(allowedValues, location, target);
79      itemRecord.addAllowedValues(allowedValuesRecord);
80    }
81  
82    @Override
83    public Void visitFlag(IFlagNodeItem item, DynamicContext context) {
84      handleAllowedValuesAtLocation(item, context);
85      return super.visitFlag(item, context);
86    }
87  
88    @Override
89    public Void visitField(IFieldNodeItem item, DynamicContext context) {
90      handleAllowedValuesAtLocation(item, context);
91      return super.visitField(item, context);
92    }
93  
94    @Override
95    public Void visitAssembly(IAssemblyNodeItem item, DynamicContext context) {
96      handleAllowedValuesAtLocation(item, context);
97  
98      return super.visitAssembly(item, context);
99    }
100 
101   @Override
102   public Void visitAssembly(IAssemblyInstanceGroupedNodeItem item, DynamicContext context) {
103     return visitAssembly((IAssemblyNodeItem) item, context);
104   }
105 
106   @Override
107   protected Void defaultResult() {
108     return null;
109   }
110 
111   public static final class NodeItemRecord {
112     @NonNull
113     private final IDefinitionNodeItem<?, ?> item;
114     @NonNull
115     private final List<AllowedValuesRecord> allowedValues = new LinkedList<>();
116 
117     private NodeItemRecord(@NonNull IDefinitionNodeItem<?, ?> item) {
118       this.item = item;
119     }
120 
121     @NonNull
122     public IDefinitionNodeItem<?, ?> getItem() {
123       return item;
124     }
125 
126     @NonNull
127     public List<AllowedValuesRecord> getAllowedValues() {
128       return allowedValues;
129     }
130 
131     public void addAllowedValues(@NonNull AllowedValuesRecord record) {
132       this.allowedValues.add(record);
133     }
134   }
135 
136   public static final class AllowedValuesRecord {
137     @NonNull
138     private final IAllowedValuesConstraint allowedValues;
139     @NonNull
140     private final IDefinitionNodeItem<?, ?> location;
141     @NonNull
142     private final IDefinitionNodeItem<?, ?> target;
143 
144     public AllowedValuesRecord(
145         @NonNull IAllowedValuesConstraint allowedValues,
146         @NonNull IDefinitionNodeItem<?, ?> location,
147         @NonNull IDefinitionNodeItem<?, ?> target) {
148       this.allowedValues = allowedValues;
149       this.location = location;
150       this.target = target;
151     }
152 
153     @NonNull
154     public IAllowedValuesConstraint getAllowedValues() {
155       return allowedValues;
156     }
157 
158     @NonNull
159     public IDefinitionNodeItem<?, ?> getLocation() {
160       return location;
161     }
162 
163     @NonNull
164     public IDefinitionNodeItem<?, ?> getTarget() {
165       return target;
166     }
167   }
168 }